{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.908756Z",
     "start_time": "2020-04-14T07:47:33.739793Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch简介\n",
    "两个核心特征：\n",
    "- 多维张量，类似`Numpy`但是可以在`GPU`上运行\n",
    "- 用于创建和训练神经网络的自动求导"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.912707Z",
     "start_time": "2020-04-14T07:47:33.909595Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.3001e+22, 2.4016e-18, 1.8788e+31],\n",
      "        [7.9303e+34, 6.1949e-04, 7.3313e+22],\n",
      "        [7.2151e+22, 2.8404e+29, 2.3089e-12],\n",
      "        [7.1856e+22, 4.3605e+27, 2.5226e-18],\n",
      "        [1.1257e+24, 1.4210e-05, 1.2849e+31]])\n"
     ]
    }
   ],
   "source": [
    "# 未初始化的张量，被创建，其值为该张量被分配的内存处的值\n",
    "x = torch.empty(5, 3)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.916032Z",
     "start_time": "2020-04-14T07:47:33.913769Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.3541,  0.6688, -0.9737],\n",
      "        [ 0.5310,  1.3178,  0.2240],\n",
      "        [ 0.2307, -0.1250, -0.9171],\n",
      "        [-0.6433,  0.0496, -0.2759],\n",
      "        [ 0.1790,  2.0357,  0.2876]])\n"
     ]
    }
   ],
   "source": [
    "# 随机初始化的张量\n",
    "x = torch.randn(5, 3)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.919024Z",
     "start_time": "2020-04-14T07:47:33.916835Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0]])\n"
     ]
    }
   ],
   "source": [
    "# 初始化为 0 的张量\n",
    "x = torch.zeros(5, 3, dtype=torch.long)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.922074Z",
     "start_time": "2020-04-14T07:47:33.919890Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.5000, 3.0000])\n"
     ]
    }
   ],
   "source": [
    "# 从数组构建张量\n",
    "x = torch.tensor([5.5, 3])\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.925951Z",
     "start_time": "2020-04-14T07:47:33.922915Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.]], dtype=torch.float64)\n",
      "tensor([[-0.0765,  0.3583, -0.0196],\n",
      "        [-0.1272, -0.1096,  0.5929],\n",
      "        [ 0.3343, -1.5624, -2.0702],\n",
      "        [ 0.2412, -1.1553,  1.0742],\n",
      "        [ 0.3540, -0.7436,  0.0934]])\n"
     ]
    }
   ],
   "source": [
    "# 基于已有张量创建\n",
    "x = x.new_ones(5, 3, dtype=torch.double)\n",
    "print(x)\n",
    "\n",
    "x = torch.randn_like(x, dtype=torch.float)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.928551Z",
     "start_time": "2020-04-14T07:47:33.926784Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 3])\n"
     ]
    }
   ],
   "source": [
    "print(x.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T12:59:42.634684Z",
     "start_time": "2020-03-31T12:59:42.629063Z"
    }
   },
   "source": [
    "## 张量操作\n",
    "\n",
    "- `torch.mm`:支持严格意义的矩阵乘法，输入形状必须严格符合矩阵乘法规则\n",
    "- `torch.matmul`:支持广播的矩阵乘法\n",
    "\n",
    "\n",
    "- `tensor.reshape(m,n)`:返回维度为`m*n`的新张量，有时张量数据在内存中的地址未变，有时返回原张量的克隆；改变前后元素个数必须相同\n",
    "- `tensor.resize_(m,n)`:就地改变，数据在内存中的位置未变；改后元素数量少于原张量，索引外元素会丢失；多于原张量，则多出部分元素未初始化\n",
    "- `tensor.view(m,n)`:返回新的张量，改变前后张量元素数量相同"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.931852Z",
     "start_time": "2020-04-14T07:47:33.929359Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0087,  0.5454,  0.2511],\n",
      "        [ 0.1559,  0.5499,  1.5617],\n",
      "        [ 0.4045, -1.0646, -1.6222],\n",
      "        [ 0.5741, -0.3339,  1.6808],\n",
      "        [ 1.0992, -0.1187,  1.0598]])\n"
     ]
    }
   ],
   "source": [
    "# 操作符\n",
    "y = torch.rand(5, 3)\n",
    "print(x + y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.935483Z",
     "start_time": "2020-04-14T07:47:33.933228Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0087,  0.5454,  0.2511],\n",
      "        [ 0.1559,  0.5499,  1.5617],\n",
      "        [ 0.4045, -1.0646, -1.6222],\n",
      "        [ 0.5741, -0.3339,  1.6808],\n",
      "        [ 1.0992, -0.1187,  1.0598]])\n"
     ]
    }
   ],
   "source": [
    "# torch函数\n",
    "print(torch.add(x, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.938980Z",
     "start_time": "2020-04-14T07:47:33.936502Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0087,  0.5454,  0.2511],\n",
      "        [ 0.1559,  0.5499,  1.5617],\n",
      "        [ 0.4045, -1.0646, -1.6222],\n",
      "        [ 0.5741, -0.3339,  1.6808],\n",
      "        [ 1.0992, -0.1187,  1.0598]])\n"
     ]
    }
   ],
   "source": [
    "# 提供变量，接受操作结果为值\n",
    "result = torch.empty(5, 3)\n",
    "torch.add(x, y, out=result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.941924Z",
     "start_time": "2020-04-14T07:47:33.939699Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0087,  0.5454,  0.2511],\n",
      "        [ 0.1559,  0.5499,  1.5617],\n",
      "        [ 0.4045, -1.0646, -1.6222],\n",
      "        [ 0.5741, -0.3339,  1.6808],\n",
      "        [ 1.0992, -0.1187,  1.0598]])\n"
     ]
    }
   ],
   "source": [
    "# 就地操作，不创建新的变量\n",
    "# 任何就地改变值的操作，以 _ 为后缀\n",
    "y.add_(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.945215Z",
     "start_time": "2020-04-14T07:47:33.942768Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])\n"
     ]
    }
   ],
   "source": [
    "# 改变形状\n",
    "x = torch.randn(4, 4)\n",
    "y = x.view(16)\n",
    "z = x.view(-1, 8)\n",
    "print(x.size(), y.size(), z.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.951249Z",
     "start_time": "2020-04-14T07:47:33.946059Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.4732332 ,  1.5613918 ,  0.42412254,  0.4733341 ],\n",
       "       [-0.07202843, -0.47456622,  2.0124426 ,  0.52497023],\n",
       "       [ 0.1319841 , -1.8435042 ,  2.5339205 ,  0.9075786 ],\n",
       "       [-0.41401067, -0.46745554,  1.7350237 ,  0.76257217]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 张量的值\n",
    "x.numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 与`Numpy`协作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.954724Z",
     "start_time": "2020-04-14T07:47:33.952080Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 1., 1., 1., 1.])\n",
      "[1. 1. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones(5)\n",
    "print(a)\n",
    "b = a.numpy()\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.957779Z",
     "start_time": "2020-04-14T07:47:33.955420Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2., 2., 2., 2., 2.])\n",
      "[2. 2. 2. 2. 2.]\n"
     ]
    }
   ],
   "source": [
    "# 改变张量的值，对应的Numpy数组的值也会改变\n",
    "a.add_(1)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:33.961793Z",
     "start_time": "2020-04-14T07:47:33.958501Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 1. 1. 1. 1.] tensor([1., 1., 1., 1., 1.], dtype=torch.float64)\n",
      "[2. 2. 2. 2. 2.] tensor([2., 2., 2., 2., 2.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# 从Numpy数组创建张量，张量与数组共享内存\n",
    "# 两者中任一改变，都会导致另一的改变\n",
    "\n",
    "import numpy as np\n",
    "a = np.ones(5)\n",
    "b = torch.from_numpy(a)\n",
    "print(a, b)\n",
    "np.add(a, 1, out=a)\n",
    "print(a, b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CUDA张量\n",
    "- 将张量移动到指定的设备上，张量的`.to`方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:34.888404Z",
     "start_time": "2020-04-14T07:47:33.962516Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.4732,  2.5614,  1.4241,  1.4733],\n",
      "        [ 0.9280,  0.5254,  3.0124,  1.5250],\n",
      "        [ 1.1320, -0.8435,  3.5339,  1.9076],\n",
      "        [ 0.5860,  0.5325,  2.7350,  1.7626]], device='cuda:0')\n",
      "tensor([[ 2.4732,  2.5614,  1.4241,  1.4733],\n",
      "        [ 0.9280,  0.5254,  3.0124,  1.5250],\n",
      "        [ 1.1320, -0.8435,  3.5339,  1.9076],\n",
      "        [ 0.5860,  0.5325,  2.7350,  1.7626]]) torch.float64\n"
     ]
    }
   ],
   "source": [
    "# Runs only when a CUDA device is present; otherwise this cell does nothing.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "    y = torch.ones_like(x, device=device)  # create the tensor directly on the GPU\n",
    "    x = x.to(device)  # move the existing tensor onto the GPU\n",
    "    z = x + y\n",
    "    print(z)\n",
    "    # .to() accepts a device and a dtype together: copy to the CPU as float64.\n",
    "    print(z.to(\"cpu\", torch.double))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 利用`Numpy`和`PyTorch`手动实现神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.048918Z",
     "start_time": "2020-04-14T07:47:34.889285Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 25015766.30401914\n",
      "100 790.992865714488\n",
      "200 10.310933129055591\n",
      "300 0.20151679512531337\n",
      "400 0.004273137757451376\n"
     ]
    }
   ],
   "source": [
    "# Numpy实现\n",
    "import numpy as np\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10  # 数据量、输入、隐藏层、输出层神经元\n",
    "\n",
    "x = np.random.randn(N, D_in)  # 训练数据\n",
    "y = np.random.randn(N, D_out)  # 训练标签\n",
    "\n",
    "w1 = np.random.randn(D_in, H)\n",
    "w2 = np.random.randn(H, D_out)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    # 前向传播\n",
    "    h = x.dot(w1)\n",
    "    h_relu = np.maximum(h, 0)  # 激活函数\n",
    "    y_pred = h_relu.dot(w2)\n",
    "\n",
    "    # 损失函数\n",
    "    loss = np.square(y_pred - y).sum()\n",
    "    if t % 100 == 0:\n",
    "        print(t, loss)\n",
    "\n",
    "    # 反向传播\n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    \n",
    "    grad_w2 = h_relu.T.dot(grad_y_pred)\n",
    "    \n",
    "    grad_h_relu = grad_y_pred.dot(w2.T)\n",
    "    \n",
    "    grad_h = grad_h_relu.copy()\n",
    "    grad_h[h < 0] = 0\n",
    "    \n",
    "    grad_w1 = x.T.dot(grad_h)\n",
    "\n",
    "    # 参数更新\n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.183113Z",
     "start_time": "2020-04-14T07:47:35.049709Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99 369.7817077636719\n",
      "199 1.402687907218933\n",
      "299 0.007994033396244049\n",
      "399 0.00018907932098954916\n",
      "499 3.431710138102062e-05\n"
     ]
    }
   ],
   "source": [
    "# PyTorch实现\n",
    "import torch\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cpu\")\n",
    "# device = torch.device(\"cuda:0\")\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in, device=device, dtype=dtype)\n",
    "y = torch.randn(N, D_out, device=device, dtype=dtype)\n",
    "\n",
    "w1 = torch.randn(D_in, H, device=device, dtype=dtype)\n",
    "w2 = torch.randn(H, D_out, device=device, dtype=dtype)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    # 前向传播\n",
    "    h = x.mm(w1)\n",
    "    h_relu = h.clamp(min=0)\n",
    "    y_pred = h_relu.mm(w2)\n",
    "\n",
    "    loss = (y_pred - y).pow(2).sum().item()\n",
    "    if t % 100 == 99:\n",
    "        print(t, loss)\n",
    "        \n",
    "    # 反向传播，手动计算\n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    grad_w2 = h_relu.t().mm(grad_y_pred)\n",
    "    grad_h_relu = grad_y_pred.mm(w2.t())\n",
    "    grad_h = grad_h_relu.clone()\n",
    "    grad_h[h < 0] = 0\n",
    "    grad_w1 = x.t().mm(grad_h)\n",
    "\n",
    "    # 更新参数\n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自动求导(`Autograd`)\n",
    "- 使用`autograd`时，网络的前向传播会定义一个无环的计算图；图中结点是张量，边为函数    \n",
    "- 将`Tensor`的`.requires_grad`设置为`True`，自动追踪作用于张量上的操作，然后调用`.backward()`方法，所有梯度会自动计算出，然后通过`.grad`属性访问\n",
    "- `.detach()`方法将张量从计算历史中删除，阻止继续对其求导\n",
    "    - 也可将代码包裹在`with torch.no_grad():`语句中阻止求导的进行，当模型训练完成，进行评估时很有用\n",
    "- `Function`类，`Tensor`和`Function`互相连通，编码整个计算过程；每个张量有`.grad_fn`属性，指向创建该张量的`Function`\n",
    "    - 用户自定义的张量`grad_fn`属性为`None`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T13:35:21.743895Z",
     "start_time": "2020-03-31T13:35:21.737788Z"
    }
   },
   "source": [
    "函数$\\vec{y}=f(\\vec{x})$，张量$\\vec{y}$对张量$\\vec{x}$的梯度是一个雅可比(`Jacobian`)矩阵：\n",
    "$$\\begin{split}J=\\left(\\begin{array}{ccc}\n",
    " \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{1}}{\\partial x_{n}}\\\\\n",
    " \\vdots & \\ddots & \\vdots\\\\\n",
    " \\frac{\\partial y_{m}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
    " \\end{array}\\right)\\end{split}$$\n",
    "`torch.autograd`计算向量-雅可比积：给定向量$v=\\left(\\begin{array}{cccc} v_{1} & v_{2} & \\cdots & v_{m}\\end{array}\\right)^{T}$，计算$v^{T}\\cdot J$。如果$v$恰好是标量函数$l=g\\left(\\vec{y}\\right)$的梯度，即$v=\\left(\\begin{array}{ccc}\\frac{\\partial l}{\\partial y_{1}} & \\cdots & \\frac{\\partial l}{\\partial y_{m}}\\end{array}\\right)^{T}$，那么根据链式法则，向量-雅可比积就是$l$对$\\vec{x}$的梯度（下式写成列向量形式$J^{T}\\cdot v$）：\n",
    "$$\n",
    "\\begin{split}J^{T}\\cdot v=\\left(\\begin{array}{ccc}\n",
    " \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{1}}\\\\\n",
    " \\vdots & \\ddots & \\vdots\\\\\n",
    " \\frac{\\partial y_{1}}{\\partial x_{n}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
    " \\end{array}\\right)\\left(\\begin{array}{c}\n",
    " \\frac{\\partial l}{\\partial y_{1}}\\\\\n",
    " \\vdots\\\\\n",
    " \\frac{\\partial l}{\\partial y_{m}}\n",
    " \\end{array}\\right)=\\left(\\begin{array}{c}\n",
    " \\frac{\\partial l}{\\partial x_{1}}\\\\\n",
    " \\vdots\\\\\n",
    " \\frac{\\partial l}{\\partial x_{n}}\n",
    " \\end{array}\\right)\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `autograd`相关操作及属性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.186744Z",
     "start_time": "2020-04-14T07:47:35.183870Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1.],\n",
      "        [1., 1.]], requires_grad=True)\n",
      "tensor([[3., 3.],\n",
      "        [3., 3.]], grad_fn=<AddBackward0>)\n",
      "<AddBackward0 object at 0x7f5dbdb8c090>\n"
     ]
    }
   ],
   "source": [
    "# requires_grad参数\n",
    "x = torch.ones(2, 2, requires_grad=True)\n",
    "print(x)\n",
    "\n",
    "# grad_fn属性\n",
    "y = x + 2\n",
    "print(y)\n",
    "print(y.grad_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.189888Z",
     "start_time": "2020-04-14T07:47:35.187489Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[27., 27.],\n",
      "        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# grad_fn属性\n",
    "z = y * y * 3\n",
    "out = z.mean()\n",
    "print(z, out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.193327Z",
     "start_time": "2020-04-14T07:47:35.190644Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "<SumBackward0 object at 0x7f5db5e3a1d0>\n"
     ]
    }
   ],
   "source": [
    "# requires_grad可以手动改变\n",
    "a = torch.randn(2, 2)\n",
    "a = ((a * 3) / (a - 1))\n",
    "print(a.requires_grad)\n",
    "\n",
    "a.requires_grad_(True)\n",
    "print(a.requires_grad)\n",
    "b = (a * a).sum()\n",
    "print(b.grad_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.196195Z",
     "start_time": "2020-04-14T07:47:35.194067Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4.5000, 4.5000],\n",
      "        [4.5000, 4.5000]])\n"
     ]
    }
   ],
   "source": [
    "# 张量的grad属性\n",
    "out.backward()\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.198116Z",
     "start_time": "2020-04-14T07:47:35.196902Z"
    }
   },
   "outputs": [],
   "source": [
    "# autograd.grad 函数手动求导，不要与backward()同时使用\n",
    "# dx = torch.autograd.grad(out, x)\n",
    "# print(dx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.202668Z",
     "start_time": "2020-04-14T07:47:35.198867Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-1023.8810,  1130.8076,   274.0225], grad_fn=<MulBackward0>)\n",
      "tensor([5.1200e+01, 5.1200e+02, 5.1200e-02])\n"
     ]
    }
   ],
   "source": [
    "# 张量的grad属性\n",
    "x = torch.randn(3, requires_grad=True)\n",
    "y = x * 2\n",
    "while y.data.norm() < 1000:\n",
    "    y = y * 2\n",
    "\n",
    "print(y)\n",
    "\n",
    "v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)\n",
    "y.backward(v)\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动求导拟合多项式\n",
    "例如：有来自于$ f(x) = 5x^2 + 3$数据集，需要拟合出该函数$f(x)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.410608Z",
     "start_time": "2020-04-14T07:47:35.203393Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yangbin7/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4.9856863e+00]\n",
      " [1.7926495e-03]\n",
      " [3.8471248e+00]]\n"
     ]
    }
   ],
   "source": [
    "# 数据集\n",
    "def generate_data():\n",
    "    x = torch.rand(100) * 20 - 10\n",
    "    y = 5 * x * x + 3\n",
    "    return x, y\n",
    "\n",
    "\n",
    "def model(x):\n",
    "    f = torch.stack([x * x, x, torch.ones_like(x)], 1)\n",
    "    yhat = torch.squeeze(f @ w, 1)\n",
    "    return yhat\n",
    "\n",
    "\n",
    "def compute_loss(y, yhat):\n",
    "    loss = torch.nn.functional.mse_loss(yhat, y)\n",
    "    return loss\n",
    "\n",
    "\n",
    "def train_step():\n",
    "    x, y = generate_data()\n",
    "    yhat = model(x)\n",
    "    loss = compute_loss(yhat, y)\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "\n",
    "w = torch.tensor(torch.randn([3, 1]), requires_grad=True)\n",
    "opt = torch.optim.Adam([w], 0.1)\n",
    "for _ in range(1000):\n",
    "    train_step()\n",
    "    \n",
    "print(w.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动求导训练简单的神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:47:35.571384Z",
     "start_time": "2020-04-14T07:47:35.413096Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99 818.3245849609375\n",
      "199 10.148862838745117\n",
      "299 0.23177969455718994\n",
      "399 0.006615933030843735\n",
      "499 0.0004388819797895849\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cpu\")\n",
    "# device = torch.device(\"cuda:0\")\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in, device=device, dtype=dtype)\n",
    "y = torch.randn(N, D_out, device=device, dtype=dtype)\n",
    "\n",
    "w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)\n",
    "w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    y_pred = x.mm(w1).clamp(min=0).mm(w2) # .clamp方法实现relu激活函数\n",
    "    loss = (y_pred - y).pow(2).sum() # 张量先平方，再求和\n",
    "\n",
    "    if t % 100 == 99:\n",
    "        print(t, loss.item())\n",
    "\n",
    "    loss.backward()\n",
    "    with torch.no_grad():\n",
    "        w1 -= learning_rate * w1.grad\n",
    "        w2 -= learning_rate * w2.grad\n",
    "        w1.grad.zero_()\n",
    "        w2.grad.zero_()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义自动求导函数\n",
    "- `torch.autograd.Function`子类，实现`forward`和`backward`方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-14T07:48:10.817678Z",
     "start_time": "2020-04-14T07:48:10.805424Z"
    }
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "forward() missing 1 required positional argument: 'input'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-54b9de2627a9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0mlearning_rate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1e-6\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m500\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m     \u001b[0mrelu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMyReLU\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m     \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: forward() missing 1 required positional argument: 'input'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class MyReLU(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, input):\n",
    "        ctx.save_for_backward(input)\n",
    "        return input.clamp(min=0)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_ouput):\n",
    "        input, = ctx.saved_tensors\n",
    "        grad_input = grad_output.clone()\n",
    "        grad_input[input < 0] = 0\n",
    "        return grad_input\n",
    "\n",
    "\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in, device=device, dtype=dtype)\n",
    "y = torch.randn(N, D_out, device=device, dtype=dtype)\n",
    "\n",
    "w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)\n",
    "w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    relu = MyReLU.apply()\n",
    "\n",
    "    y_pred = relu(x.mm(w1)).mm(w2)\n",
    "\n",
    "    loss = (y_pred - y).pow(2).sum()\n",
    "    if t % 100 == 99:\n",
    "        print(t, loss.item())\n",
    "\n",
    "    with torch.no_grad():\n",
    "        w1 -= learning_rate * w1.grad\n",
    "        w2 -= learning_rate * w2.grad\n",
    "        w1.grad.zero_()\n",
    "        w2.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 神经网络\n",
    "\n",
    "\n",
    "### 创建`Sequential`模型:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T07:51:20.139800Z",
     "start_time": "2020-04-01T07:51:19.923274Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(D_in, H),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(H, D_out),\n",
    ")\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "learning_rate = 1e-4\n",
    "for t in range(500):\n",
    "    y_pred = model(x)\n",
    "\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    if t % 100 == 99:\n",
    "        print(t, loss.item())\n",
    "\n",
    "    model.zero_grad()\n",
    "\n",
    "    loss.backward()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for param in model.parameters():\n",
    "            param -= learning_rate * param.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建模型，继承自`nn.Module`模块\n",
    "\n",
    "- `__init__`初始化定义各个层\n",
    "- `forward`方法，模型的推导\n",
    "    - 当使用`autograd`时，模型会自动创建`backward`方法\n",
    "- 可学习的参数，`net.parameters()`    \n",
    "- `torch.nn`只支持批量的数据，而不是单个数据，如`conv2d`层接受维度为`nSamples x nChannels x Height x Width`的张量\n",
    "    - `input.unsqueeze(0)`给单个数据增加维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T13:46:19.853518Z",
     "start_time": "2020-03-31T13:46:19.846862Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T13:53:22.052219Z",
     "start_time": "2020-03-31T13:53:22.042982Z"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 3)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 3)\n",
    "        self.fc1 = nn.Linear(16 * 6 * 6, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
    "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 正方形的核，指定一个值也可以\n",
    "        x = x.view(-1, self.num_flat_features(x))\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "    def num_flat_features(self, x):\n",
    "        size = x.size()[1:]\n",
    "        num_features = 1\n",
    "        for s in size:\n",
    "            num_features *= s\n",
    "        return num_features\n",
    "\n",
    "\n",
    "net = Net()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T13:57:54.621006Z",
     "start_time": "2020-03-31T13:57:54.614959Z"
    }
   },
   "outputs": [],
   "source": [
    "params = list(net.parameters())\n",
    "print(len(params))  # weight and bias of the two conv layers and the three fc layers\n",
    "print(params[0].size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T13:59:31.896995Z",
     "start_time": "2020-03-31T13:59:31.889187Z"
    }
   },
   "outputs": [],
   "source": [
    "# Random input with a leading batch dimension of 1: shape (1, 1, 32, 32).\n",
    "input = torch.randn(1, 1, 32, 32)\n",
    "out = net(input)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:00:38.350528Z",
     "start_time": "2020-03-31T14:00:38.343383Z"
    }
   },
   "outputs": [],
   "source": [
    "net.zero_grad()  # reset the gradients of all parameters to zero\n",
    "out.backward(torch.randn(1, 10))  # backpropagate a random (1, 10) tensor as the upstream gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 损失函数\n",
    ">整个网络的架构：\n",
    "```\n",
    "input-> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d\n",
    "     -> view -> linear -> relu -> linear -> relu -> linear\n",
    "     -> MSELoss\n",
    "     -> loss\n",
    "```   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:05:58.235588Z",
     "start_time": "2020-03-31T14:05:58.228530Z"
    }
   },
   "outputs": [],
   "source": [
    "# Compare the network output with a random target using mean squared error.\n",
    "output = net(input)\n",
    "target = torch.randn(10).view(1, -1)  # reshape the dummy target to (1, 10)\n",
    "criterion = nn.MSELoss()\n",
    "loss = criterion(output, target)\n",
    "\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:09:25.277933Z",
     "start_time": "2020-03-31T14:09:25.272214Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step backwards through the autograd graph, starting from the loss.\n",
    "print(loss.grad_fn)  # function that produced the loss\n",
    "print(loss.grad_fn.next_functions[0][0])  # its first input function\n",
    "print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # one step further back"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:11:05.016465Z",
     "start_time": "2020-03-31T14:11:05.011022Z"
    }
   },
   "source": [
    "- 反向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:11:09.013237Z",
     "start_time": "2020-03-31T14:11:09.005885Z"
    }
   },
   "outputs": [],
   "source": [
    "net.zero_grad()  # clear existing gradients so the values printed below come from this backward pass only\n",
    "\n",
    "print('conv1.bias.grad before backward')\n",
    "print(net.conv1.bias.grad)\n",
    "\n",
    "loss.backward()\n",
    "\n",
    "print('conv1.bias.grad after backward')\n",
    "print(net.conv1.bias.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 更新参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:13:19.057536Z",
     "start_time": "2020-03-31T14:13:19.051929Z"
    }
   },
   "outputs": [],
   "source": [
    "# Manual SGD step in plain Python: p <- p - learning_rate * p.grad\n",
    "learning_rate = 0.01\n",
    "# Update in place under no_grad so autograd does not record the step; this replaces\n",
    "# the legacy `.data` access, which bypasses autograd's safety checks.\n",
    "with torch.no_grad():\n",
    "    for f in net.parameters():\n",
    "        if f.grad is None:\n",
    "            # A parameter that received no gradient is left untouched.\n",
    "            continue\n",
    "        f.sub_(f.grad * learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-31T14:15:46.201957Z",
     "start_time": "2020-03-31T14:15:46.193655Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optimizer\n",
    "import torch.optim as optim\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
    "\n",
    "# One training step: clear gradients, forward pass, loss, backward pass, parameter update.\n",
    "optimizer.zero_grad()\n",
    "output = net(input)\n",
    "loss = criterion(output, target)\n",
    "loss.backward()\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 控制流和权重共享\n",
    "- 定义特别的网络：输入层之后，每次前向传播随机重复1到3次同一个隐藏层（共享同一组权重），最后接输出层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:05:15.849289Z",
     "start_time": "2020-04-01T08:05:15.490441Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "\n",
    "\n",
    "class DynamicNet(torch.nn.Module):\n",
    "    def __init__(self, D_in, H, D_out):\n",
    "        super(DynamicNet, self).__init__()\n",
    "        self.input_layer = torch.nn.Linear(D_in, H)\n",
    "        self.middle_layer = torch.nn.Linear(H, H)\n",
    "        self.output_layer = torch.nn.Linear(H, D_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        h_relu = self.input_layer(x).clamp(min=0)\n",
    "        for _ in range(random.randint(1, 3)):  # 随机选择多个前向传播层，且共用参数\n",
    "            h_relu = self.middle_layer(h_relu).clamp(min=0)\n",
    "        y_pred = self.output_layer(h_relu)\n",
    "        return y_pred\n",
    "\n",
    "\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Random inputs and targets for the toy regression problem.\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "model = DynamicNet(D_in, H, D_out)\n",
    "\n",
    "criterion = torch.nn.MSELoss(reduction='sum')\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)\n",
    "for t in range(500):\n",
    "    # Gradients are cleared first; the forward pass does not read them.\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    y_pred = model(x)\n",
    "    loss = criterion(y_pred, y)\n",
    "    if (t + 1) % 100 == 0:\n",
    "        # Report the loss on every 100th iteration.\n",
    "        print(t, loss.item())\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建图片分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:09:27.964441Z",
     "start_time": "2020-04-01T06:09:27.861559Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:09:34.901050Z",
     "start_time": "2020-04-01T06:09:34.305959Z"
    }
   },
   "outputs": [],
   "source": [
    "# Convert images to tensors, then normalize each of the 3 channels with mean 0.5 and std 0.5.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "# Training split read from ./datasets (download=False), served in shuffled mini-batches of 4.\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    './datasets', train=True, download=False, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=4, shuffle=True, num_workers=2)\n",
    "\n",
    "# Test split with the same preprocessing, iterated without shuffling.\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    './datasets', train=False, download=False, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=4, shuffle=False, num_workers=2)\n",
    "\n",
    "# Class names indexed by integer label.\n",
    "classes = (\n",
    "    'plane', 'car', 'bird', 'cat', 'deer',\n",
    "    'dog', 'frog', 'horse', 'ship', 'truck',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:11:10.457353Z",
     "start_time": "2020-04-01T06:11:10.146187Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    \"\"\"Display a normalized channel-first image tensor with matplotlib.\n",
    "\n",
    "    `img / 2 + 0.5` undoes a normalization with mean 0.5 and std 0.5, and the\n",
    "    transpose moves axis 0 (channels) to the last position before plotting.\n",
    "    \"\"\"\n",
    "    unnormalized = img / 2 + 0.5\n",
    "    channels_last = np.transpose(unnormalized.numpy(), (1, 2, 0))\n",
    "    plt.imshow(channels_last)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Fetch one training batch. The built-in next() works with every DataLoader iterator,\n",
    "# whereas the iterator's own .next() method is not available in newer PyTorch releases.\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# Print one class name per image actually in the batch rather than assuming 4 images.\n",
    "print(\" \".join('%5s' % classes[labels[j]] for j in range(len(labels))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:11:46.611796Z",
     "start_time": "2020-04-01T06:11:46.602707Z"
    }
   },
   "outputs": [],
   "source": [
    "# Shape of the most recently loaded image batch.\n",
    "images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义CNN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:25:02.052347Z",
     "start_time": "2020-04-01T06:25:02.043485Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)  # 输入3通道，输出6通道，滤波器尺寸5\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "net = Net()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义损失函数和优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:25:03.687268Z",
     "start_time": "2020-04-01T06:25:03.681919Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "criterion = nn.CrossEntropyLoss()  # takes the raw per-class scores, not softmax-normalized values\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:27:27.480081Z",
     "start_time": "2020-04-01T06:25:15.357185Z"
    }
   },
   "outputs": [],
   "source": [
    "for epoch in range(3):\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader):\n",
    "        inputs, labels = data\n",
    "\n",
    "        # Forward pass and loss on the current mini-batch.\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backpropagate, then apply the parameter update.\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        if (i + 1) % 2000 == 0:\n",
    "            # Mean loss over the last 2000 mini-batches.\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 2000))\n",
    "            running_loss = 0.0\n",
    "print(\"Finished Training\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:33:27.905310Z",
     "start_time": "2020-04-01T06:33:27.899839Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "PATH = './models/cifar_net.pth'\n",
    "# torch.save does not create missing parent directories, so make sure ./models exists first.\n",
    "os.makedirs(os.path.dirname(PATH), exist_ok=True)\n",
    "torch.save(net.state_dict(), PATH)  # save only the weights (state_dict), not the whole module"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:33:33.931134Z",
     "start_time": "2020-04-01T06:33:33.839708Z"
    }
   },
   "outputs": [],
   "source": [
    "dataiter = iter(testloader)\n",
    "# Use the built-in next(): the iterator's own .next() method is not available in newer PyTorch releases.\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# One ground-truth name per image actually in the batch rather than assuming 4 images.\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:33:41.528482Z",
     "start_time": "2020-04-01T06:33:41.521423Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build a fresh network and restore the weights stored at PATH.\n",
    "net = Net()\n",
    "net.load_state_dict(torch.load(PATH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:33:45.010779Z",
     "start_time": "2020-04-01T06:33:45.003773Z"
    }
   },
   "outputs": [],
   "source": [
    "outputs = net(images)\n",
    "\n",
    "# torch.max along dim 1 returns (values, indices); the index of the highest score is the predicted class.\n",
    "_, predicted = torch.max(outputs, 1)\n",
    "\n",
    "print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:36:30.104133Z",
     "start_time": "2020-04-01T06:36:28.040467Z"
    }
   },
   "outputs": [],
   "source": [
    "# Overall accuracy: count correct predictions over the whole test loader.\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():  # evaluation only, no gradients needed\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)  # index of the highest score per sample\n",
    "        total += labels.size(0)  # number of samples in this batch\n",
    "        correct += (predicted == labels).sum().item()\n",
    "print('Accuracy of the network on the 10000 test images: %d %%' %\n",
    "      (100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:36:41.510401Z",
     "start_time": "2020-04-01T06:36:39.406625Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-class accuracy: correct and total counts, indexed by integer class label.\n",
    "class_correct = [0.] * 10\n",
    "class_total = [0.] * 10\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        # Pair every label with its hit/miss flag. Iterating over the real batch keeps\n",
    "        # the counts right for any batch size (a hard-coded 4 would skip samples of a\n",
    "        # larger batch and index past the end of a smaller one).\n",
    "        hits = (predicted == labels).tolist()\n",
    "        for label, hit in zip(labels.tolist(), hits):\n",
    "            class_correct[label] += hit\n",
    "            class_total[label] += 1\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' %\n",
    "          (classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU上训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:36:51.600909Z",
     "start_time": "2020-04-01T06:36:51.594883Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:41:36.560821Z",
     "start_time": "2020-04-01T06:39:40.062896Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "net = Net()\n",
    "net.to(device)\n",
    "# The existing optimizer was built from the parameters of the previous `net` object, so it\n",
    "# would never update this new network. Create a fresh one bound to the new parameters\n",
    "# (after the move to `device`).\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "for epoch in range(5):\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # Move the batch to the same device as the model.\n",
    "        inputs, labels = data[0].to(device), data[1].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        # Report every 2000 mini-batches so the averaging window matches the divisor below.\n",
    "        if i % 2000 == 1999:\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 2000))\n",
    "            running_loss = 0.0\n",
    "print(\"Finished Training\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T06:41:55.673281Z",
     "start_time": "2020-04-01T06:41:53.191068Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-class accuracy of the model on `device`: counts are indexed by integer class label.\n",
    "class_correct = [0.] * 10\n",
    "class_total = [0.] * 10\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data[0].to(device), data[1].to(device)\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        # tolist() yields plain Python values. Iterating over the real batch keeps the\n",
    "        # counts right for any batch size (a hard-coded 4 would skip samples of a larger\n",
    "        # batch and index past the end of a smaller one).\n",
    "        hits = (predicted == labels).tolist()\n",
    "        for label, hit in zip(labels.tolist(), hits):\n",
    "            class_correct[label] += hit\n",
    "            class_total[label] += 1\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' %\n",
    "          (classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 更高阶模型 \n",
    "https://www.kaggle.com/aleksandradeis/classifying-cifar-10-images-with-densenet/notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:42:45.750607Z",
     "start_time": "2020-04-01T08:42:45.252583Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Training pipeline: random horizontal flip (p=0.40) and RandomRotation(30) as augmentation,\n",
    "# then tensor conversion and per-channel normalization with mean 0.5 and std 0.5.\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(p=.40),\n",
    "    transforms.RandomRotation(30),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "])\n",
    "\n",
    "# Test pipeline: no augmentation, same tensor conversion and normalization.\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "])\n",
    "\n",
    "\n",
    "def _get_cifar10_dataloader(transform, train, batch_size, num_workers, shuffle):\n",
    "    \"\"\"Build a DataLoader over one CIFAR-10 split stored under ./datasets.\n",
    "\n",
    "    Shared implementation of get_training_dataloader and get_testing_dataloader.\n",
    "    `train` selects the split; the dataset is opened with download=False.\n",
    "    \"\"\"\n",
    "    dataset = torchvision.datasets.CIFAR10(root='./datasets',\n",
    "                                           train=train,\n",
    "                                           download=False,\n",
    "                                           transform=transform)\n",
    "    return DataLoader(dataset,\n",
    "                      shuffle=shuffle,\n",
    "                      num_workers=num_workers,\n",
    "                      batch_size=batch_size)\n",
    "\n",
    "\n",
    "def get_training_dataloader(train_transform,\n",
    "                            batch_size=128,\n",
    "                            num_workers=0,\n",
    "                            shuffle=True):\n",
    "    \"\"\"Return a DataLoader over the CIFAR-10 training split.\n",
    "\n",
    "    Args:\n",
    "        train_transform: transform applied to every training image.\n",
    "        batch_size: number of samples per batch.\n",
    "        num_workers: number of loader worker processes.\n",
    "        shuffle: whether the loader shuffles the samples.\n",
    "    \"\"\"\n",
    "    return _get_cifar10_dataloader(train_transform,\n",
    "                                   train=True,\n",
    "                                   batch_size=batch_size,\n",
    "                                   num_workers=num_workers,\n",
    "                                   shuffle=shuffle)\n",
    "\n",
    "\n",
    "def get_testing_dataloader(test_transform,\n",
    "                           batch_size=128,\n",
    "                           num_workers=0,\n",
    "                           shuffle=True):\n",
    "    \"\"\"Return a DataLoader over the CIFAR-10 test split.\n",
    "\n",
    "    Args:\n",
    "        test_transform: transform applied to every test image.\n",
    "        batch_size: number of samples per batch.\n",
    "        num_workers: number of loader worker processes.\n",
    "        shuffle: whether the loader shuffles the samples (defaults to True here as well).\n",
    "    \"\"\"\n",
    "    return _get_cifar10_dataloader(test_transform,\n",
    "                                   train=False,\n",
    "                                   batch_size=batch_size,\n",
    "                                   num_workers=num_workers,\n",
    "                                   shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Rebind trainloader / testloader to loaders built with the helpers' default arguments (batch_size=128).\n",
    "trainloader = get_training_dataloader(train_transform)\n",
    "testloader = get_testing_dataloader(test_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:42:51.403311Z",
     "start_time": "2020-04-01T08:42:51.397610Z"
    }
   },
   "outputs": [],
   "source": [
    "classes_dict = {\n",
    "    0: 'airplane',\n",
    "    1: 'automobile',\n",
    "    2: 'bird',\n",
    "    3: 'cat',\n",
    "    4: 'deer',\n",
    "    5: 'dog',\n",
    "    6: 'frog',\n",
    "    7: 'horse',\n",
    "    8: 'ship',\n",
    "    9: 'truck'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:47:00.330338Z",
     "start_time": "2020-04-01T08:46:59.924579Z"
    }
   },
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(5, 5, figsize=(10, 10))\n",
    "# Only the first batch is needed for the preview grid.\n",
    "inputs, labels = next(iter(trainloader))\n",
    "# Hide every axis up front so unused cells stay blank if the batch holds fewer than 25 images.\n",
    "for ax in axs.flat:\n",
    "    ax.axis('off')\n",
    "for im in range(min(25, len(inputs))):\n",
    "    # Undo the mean-0.5 / std-0.5 normalization so values are back in [0, 1];\n",
    "    # imshow clips float RGB data outside that range, which distorts the picture.\n",
    "    image = (inputs[im] / 2 + 0.5).permute(1, 2, 0)\n",
    "    i = im // 5\n",
    "    j = im % 5\n",
    "    axs[i, j].imshow(image.numpy())\n",
    "    axs[i, j].set_title(classes_dict[int(labels[im])])\n",
    "plt.suptitle('CIFAR-10 Images')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:59:29.131829Z",
     "start_time": "2020-04-01T08:59:29.114448Z"
    }
   },
   "outputs": [],
   "source": [
    "#\"\"\"Bottleneck layers. Although each layer only produces k\n",
    "#output feature-maps, it typically has many more inputs. It\n",
    "#has been noted in [37, 11] that a 1×1 convolution can be in-\n",
    "#troduced as bottleneck layer before each 3×3 convolution\n",
    "#to reduce the number of input feature-maps, and thus to\n",
    "#improve computational efficiency.\"\"\"\n",
    "class Bottleneck(nn.Module):\n",
    "    def __init__(self, in_channels, growth_rate):\n",
    "        super().__init__()\n",
    "        #\"\"\"In  our experiments, we let each 1×1 convolution\n",
    "        #produce 4k feature-maps.\"\"\"\n",
    "        inner_channel = 4 * growth_rate\n",
    "\n",
    "        #\"\"\"We find this design especially effective for DenseNet and\n",
    "        #we refer to our network with such a bottleneck layer, i.e.,\n",
    "        #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H ` ,\n",
    "        #as DenseNet-B.\"\"\"\n",
    "        self.bottle_neck = nn.Sequential(\n",
    "            nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False),\n",
    "            nn.BatchNorm2d(inner_channel), nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(inner_channel,\n",
    "                      growth_rate,\n",
    "                      kernel_size=3,\n",
    "                      padding=1,\n",
    "                      bias=False))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return torch.cat([x, self.bottle_neck(x)], 1)\n",
    "\n",
    "\n",
    "#\"\"\"We refer to layers between blocks as transition\n",
    "#layers, which do convolution and pooling.\"\"\"\n",
    "class Transition(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super().__init__()\n",
    "        #\"\"\"The transition layers used in our experiments\n",
    "        #consist of a batch normalization layer and an 1×1\n",
    "        #convolutional layer followed by a 2×2 average pooling\n",
    "        #layer\"\"\".\n",
    "        self.down_sample = nn.Sequential(\n",
    "            nn.BatchNorm2d(in_channels),\n",
    "            nn.Conv2d(in_channels, out_channels, 1, bias=False),\n",
    "            nn.AvgPool2d(2, stride=2))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.down_sample(x)\n",
    "\n",
    "\n",
    "# DenseNet-BC\n",
    "# B stands for the bottleneck layer (BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3)).\n",
    "# C stands for the compression factor theta (0 < theta <= 1).\n",
    "class DenseNet(nn.Module):\n",
    "    def __init__(self,\n",
    "                 block,\n",
    "                 nblocks,\n",
    "                 growth_rate=12,\n",
    "                 reduction=0.5,\n",
    "                 num_class=10):\n",
    "        \"\"\"Assemble the network.\n",
    "\n",
    "        Args:\n",
    "            block: layer class, called as block(in_channels, growth_rate).\n",
    "            nblocks: number of layers in each dense block, one entry per block.\n",
    "            growth_rate: channel increment assumed for every layer.\n",
    "            reduction: compression factor applied by each transition layer.\n",
    "            num_class: size of the final linear output.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.growth_rate = growth_rate\n",
    "\n",
    "        # The stem convolution outputs twice the growth rate in channels.\n",
    "        channels = 2 * growth_rate\n",
    "\n",
    "        # 3x3 kernel with one pixel of zero padding keeps the feature-map size fixed.\n",
    "        self.conv1 = nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False)\n",
    "\n",
    "        self.features = nn.Sequential()\n",
    "\n",
    "        # Every dense block except the last one is followed by a transition layer.\n",
    "        last_index = len(nblocks) - 1\n",
    "        for index, depth in enumerate(nblocks[:last_index]):\n",
    "            self.features.add_module(\n",
    "                \"dense_block_layer_{}\".format(index),\n",
    "                self._make_dense_layers(block, channels, depth))\n",
    "            channels += growth_rate * depth\n",
    "\n",
    "            # A transition after a block with m feature maps outputs int(reduction * m)\n",
    "            # feature maps; int() truncates the product.\n",
    "            compressed = int(reduction * channels)\n",
    "            self.features.add_module(\"transition_layer_{}\".format(index),\n",
    "                                     Transition(channels, compressed))\n",
    "            channels = compressed\n",
    "\n",
    "        # Final dense block, followed by batch norm and ReLU instead of a transition.\n",
    "        final_depth = nblocks[last_index]\n",
    "        self.features.add_module(\n",
    "            \"dense_block{}\".format(last_index),\n",
    "            self._make_dense_layers(block, channels, final_depth))\n",
    "        channels += growth_rate * final_depth\n",
    "        self.features.add_module('bn', nn.BatchNorm2d(channels))\n",
    "        self.features.add_module('activation', nn.ReLU(inplace=True))\n",
    "\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "\n",
    "        self.linear = nn.Linear(channels, num_class)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class scores, one row of num_class values per input sample.\"\"\"\n",
    "        feature_maps = self.features(self.conv1(x))\n",
    "        pooled = self.avgpool(feature_maps)\n",
    "        # Flatten everything except the batch dimension before the linear layer.\n",
    "        flattened = pooled.view(pooled.size()[0], -1)\n",
    "        return self.linear(flattened)\n",
    "\n",
    "    def _make_dense_layers(self, block, in_channels, nblocks):\n",
    "        \"\"\"Stack `nblocks` layers; the input width given to layer i is in_channels + i * growth_rate.\"\"\"\n",
    "        dense_block = nn.Sequential()\n",
    "        for index in range(nblocks):\n",
    "            layer_inputs = in_channels + index * self.growth_rate\n",
    "            dense_block.add_module('bottle_neck_layer_{}'.format(index),\n",
    "                                   block(layer_inputs, self.growth_rate))\n",
    "        return dense_block\n",
    "\n",
    "\n",
    "def densenet121(activation='relu'):\n",
    "    \"\"\"Build a DenseNet of Bottleneck layers with block sizes [6, 12, 24, 16] and growth rate 32.\n",
    "\n",
    "    `activation` is accepted but not used.\n",
    "    \"\"\"\n",
    "    return DenseNet(block=Bottleneck, nblocks=[6, 12, 24, 16], growth_rate=32)\n",
    "\n",
    "\n",
    "def densenet169(activation='relu'):\n",
    "    \"\"\"Build a DenseNet of Bottleneck layers with block sizes [6, 12, 32, 32] and growth rate 32.\n",
    "\n",
    "    `activation` is accepted but not used.\n",
    "    \"\"\"\n",
    "    return DenseNet(block=Bottleneck, nblocks=[6, 12, 32, 32], growth_rate=32)\n",
    "\n",
    "\n",
    "def densenet201(activation='relu'):\n",
    "    \"\"\"Build a DenseNet of Bottleneck layers with block sizes [6, 12, 48, 32] and growth rate 32.\n",
    "\n",
    "    `activation` is accepted but not used.\n",
    "    \"\"\"\n",
    "    return DenseNet(block=Bottleneck, nblocks=[6, 12, 48, 32], growth_rate=32)\n",
    "\n",
    "\n",
    "def densenet161(activation='relu'):\n",
    "    \"\"\"Build a DenseNet of Bottleneck layers with block sizes [6, 12, 36, 24] and growth rate 48.\n",
    "\n",
    "    `activation` is accepted but not used.\n",
    "    \"\"\"\n",
    "    return DenseNet(block=Bottleneck, nblocks=[6, 12, 36, 24], growth_rate=48)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:59:30.344991Z",
     "start_time": "2020-04-01T08:59:30.338709Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of passes over the training set.\n",
    "epochs = 50\n",
    "# Learning rate handed to the optimizer.\n",
    "learning_rate = 0.001\n",
    "# Device to use: the first CUDA device when available, otherwise the CPU.\n",
    "# Don't forget to turn on the GPU in the kernel's settings.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T08:59:51.582267Z",
     "start_time": "2020-04-01T08:59:51.539599Z"
    }
   },
   "outputs": [],
   "source": [
    "# densenet121 builds a DenseNet of Bottleneck layers with block sizes [6, 12, 24, 16] and growth rate 32.\n",
    "model = densenet121()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T09:01:08.555480Z",
     "start_time": "2020-04-01T09:01:08.549055Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.optim import Adam\n",
    "# Cross-entropy loss applied to the model's raw outputs.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Adam over all model parameters, using the learning rate configured above.\n",
    "optimizer = Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T09:01:31.433030Z",
     "start_time": "2020-04-01T09:01:31.281055Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Empty table with one column per tracked metric.\n",
    "# NOTE(review): presumably one row is appended per epoch by the training loop - confirm.\n",
    "train_stats = pd.DataFrame(columns=[\n",
    "    'Epoch', 'Time per epoch', 'Avg time per step', 'Train loss',\n",
    "    'Train accuracy', 'Train top-3 accuracy', 'Test loss', 'Test accuracy',\n",
    "    'Test top-3 accuracy'\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:20:46.731414Z",
     "start_time": "2020-04-01T09:02:58.530834Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "#train the model\n",
    "model.to(device)\n",
    "\n",
    "steps = 0\n",
    "running_loss = 0\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    since = time.time()\n",
    "\n",
    "    train_accuracy = 0\n",
    "    top3_train_accuracy = 0\n",
    "    for inputs, labels in trainloader:\n",
    "        steps += 1\n",
    "        # Move input and label tensors to the default device\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        logps = model.forward(inputs)\n",
    "        loss = criterion(logps, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        # calculate train top-1 accuracy\n",
    "        ps = torch.exp(logps)\n",
    "        top_p, top_class = ps.topk(1, dim=1)\n",
    "        equals = top_class == labels.view(*top_class.shape)\n",
    "        train_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()\n",
    "\n",
    "        # Calculate train top-3 accuracy\n",
    "        np_top3_class = ps.topk(3, dim=1)[1].cpu().numpy()\n",
    "        target_numpy = labels.cpu().numpy()\n",
    "        top3_train_accuracy += np.mean([\n",
    "            1 if target_numpy[i] in np_top3_class[i] else 0\n",
    "            for i in range(0, len(target_numpy))\n",
    "        ])\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "\n",
    "    test_loss = 0\n",
    "    test_accuracy = 0\n",
    "    top3_test_accuracy = 0\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in testloader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            logps = model.forward(inputs)\n",
    "            batch_loss = criterion(logps, labels)\n",
    "\n",
    "            test_loss += batch_loss.item()\n",
    "\n",
    "            # Calculate test top-1 accuracy\n",
    "            ps = torch.exp(logps)\n",
    "            top_p, top_class = ps.topk(1, dim=1)\n",
    "            equals = top_class == labels.view(*top_class.shape)\n",
    "            test_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()\n",
    "\n",
    "            # Calculate test top-3 accuracy\n",
    "            np_top3_class = ps.topk(3, dim=1)[1].cpu().numpy()\n",
    "            target_numpy = labels.cpu().numpy()\n",
    "            top3_test_accuracy += np.mean([\n",
    "                1 if target_numpy[i] in np_top3_class[i] else 0\n",
    "                for i in range(0, len(target_numpy))\n",
    "            ])\n",
    "\n",
    "    print(\n",
    "        f\"Epoch {epoch+1}/{epochs}.. \"\n",
    "        f\"Time per epoch: {time_elapsed:.4f}.. \"\n",
    "        f\"Average time per step: {time_elapsed/len(trainloader):.4f}.. \"\n",
    "        f\"Train loss: {running_loss/len(trainloader):.4f}.. \"\n",
    "        f\"Train accuracy: {train_accuracy/len(trainloader):.4f}.. \"\n",
    "        f\"Top-3 train accuracy: {top3_train_accuracy/len(trainloader):.4f}.. \"\n",
    "        f\"Test loss: {test_loss/len(testloader):.4f}.. \"\n",
    "        f\"Test accuracy: {test_accuracy/len(testloader):.4f}.. \"\n",
    "        f\"Top-3 test accuracy: {top3_test_accuracy/len(testloader):.4f}\")\n",
    "\n",
    "    train_stats = train_stats.append(\n",
    "        {\n",
    "            'Epoch': epoch,\n",
    "            'Time per epoch': time_elapsed,\n",
    "            'Avg time per step': time_elapsed / len(trainloader),\n",
    "            'Train loss': running_loss / len(trainloader),\n",
    "            'Train accuracy': train_accuracy / len(trainloader),\n",
    "            'Train top-3 accuracy': top3_train_accuracy / len(trainloader),\n",
    "            'Test loss': test_loss / len(testloader),\n",
    "            'Test accuracy': test_accuracy / len(testloader),\n",
    "            'Test top-3 accuracy': top3_test_accuracy / len(testloader)\n",
    "        },\n",
    "        ignore_index=True)\n",
    "\n",
    "    running_loss = 0\n",
    "    model.train()\n",
    "    \n",
    "# train_stats.to_csv('train_log_DenseNet121.csv')    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:51:50.121907Z",
     "start_time": "2020-04-01T10:51:49.960857Z"
    }
   },
   "outputs": [],
   "source": [
    "PATH = './models/cifar_net.pth'\n",
    "torch.save(model.state_dict(), PATH) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:49:28.752995Z",
     "start_time": "2020-04-01T10:49:28.651300Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 7))\n",
    "ax = plt.axes()\n",
    "\n",
    "plt.title(\"Loss\")\n",
    "plt.xlabel(\"Epochs\")\n",
    "plt.ylabel(\"Loss\")\n",
    "\n",
    "x = range(1, len(train_stats['Train loss'].values) + 1)\n",
    "ax.plot(x, train_stats['Train loss'].values, '-g', label='train loss')\n",
    "ax.plot(x, train_stats['Test loss'].values, '-b', label='test loss')\n",
    "\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:49:38.861538Z",
     "start_time": "2020-04-01T10:49:38.762240Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10,7))\n",
    "ax = plt.axes()\n",
    "\n",
    "plt.title(\"Accuracy\")\n",
    "plt.xlabel(\"Epochs\")\n",
    "plt.ylabel(\"Accuracy\");\n",
    "\n",
    "x = range(1, len(train_stats['Train accuracy'].values) + 1)\n",
    "ax.plot(x, train_stats['Train accuracy'].values, '-g', label='train accuracy');\n",
    "ax.plot(x, train_stats['Test accuracy'].values, '-b', label='test accuracy');\n",
    "\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:56:25.501419Z",
     "start_time": "2020-04-01T10:56:25.494323Z"
    }
   },
   "outputs": [],
   "source": [
    "def view_classify(img, ps, title):\n",
    "    \"\"\"\n",
    "    Function for viewing an image and it's predicted classes\n",
    "    with matplotlib.\n",
    "\n",
    "    INPUT:\n",
    "        img - (tensor) image file\n",
    "        ps - (tensor) predicted probabilities for each class\n",
    "        title - (str) string with true label\n",
    "    \"\"\"\n",
    "    ps = ps.data.numpy().squeeze()\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)\n",
    "    image = img.permute(1, 2, 0)\n",
    "    ax1.imshow(image.numpy())\n",
    "    ax1.axis('off')\n",
    "    ax2.barh(np.arange(10), ps)\n",
    "    ax2.set_aspect(0.1)\n",
    "    ax2.set_yticks(np.arange(10))\n",
    "    ax2.set_yticklabels(list(classes_dict.values()), size='small')\n",
    "    ax2.set_title(title)\n",
    "    ax2.set_xlim(0, 1.1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-01T10:56:26.599508Z",
     "start_time": "2020-04-01T10:56:26.416586Z"
    }
   },
   "outputs": [],
   "source": [
    "for batch_idx, (inputs, labels) in enumerate(testloader):\n",
    "    inputs, labels = inputs.to(device), labels.to(device)\n",
    "    img = inputs[0]\n",
    "    label_true = labels[0]\n",
    "    ps = model(inputs)\n",
    "    view_classify(img.cpu(), torch.softmax(ps[0].cpu(), dim=0), classes_dict[int(label_true.cpu().numpy())])\n",
    "    \n",
    "    break;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:21:32.697019Z",
     "start_time": "2020-04-03T09:21:32.220239Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:22:24.511575Z",
     "start_time": "2020-04-03T09:22:24.506427Z"
    }
   },
   "outputs": [],
   "source": [
    "x = torch.tensor(1.0, requires_grad=True)\n",
    "\n",
    "\n",
    "def u(x):\n",
    "    return x * x\n",
    "\n",
    "\n",
    "def g(u):\n",
    "    return -u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:23:05.779460Z",
     "start_time": "2020-04-03T09:23:05.753718Z"
    }
   },
   "outputs": [],
   "source": [
    "dgdx = torch.autograd.grad(g(u(x)),x)\n",
    "print(dgdx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:24:30.840343Z",
     "start_time": "2020-04-03T09:24:30.833414Z"
    }
   },
   "outputs": [],
   "source": [
    "y = g(u(x))\n",
    "y.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:24:36.064391Z",
     "start_time": "2020-04-03T09:24:36.058607Z"
    }
   },
   "outputs": [],
   "source": [
    "y.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-03T09:24:49.829686Z",
     "start_time": "2020-04-03T09:24:49.790178Z"
    }
   },
   "outputs": [],
   "source": [
    "u.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:02:55.005703Z",
     "start_time": "2020-04-15T07:02:54.506993Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:04:30.205997Z",
     "start_time": "2020-04-15T07:04:30.149199Z"
    }
   },
   "outputs": [],
   "source": [
    "x = torch.tensor(1., requires_grad=True)\n",
    "w = torch.tensor(2., requires_grad=True)\n",
    "b = torch.tensor(3., requires_grad=True)\n",
    "\n",
    "y = w * x + b\n",
    "y.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:04:54.632425Z",
     "start_time": "2020-04-15T07:04:54.621991Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.)\n",
      "tensor(1.)\n",
      "tensor(1.)\n"
     ]
    }
   ],
   "source": [
    "print(x.grad)\n",
    "print(w.grad)\n",
    "print(b.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 前向传播自动进行；反向求梯度，必须指定张量的`requires_grad`属性，并调用`.backward()`方法\n",
    "\n",
    "- `nn`模块内置层，默认设置了`requires_gras=True`\n",
    "    - `nn.Linear`类的`weight`和`bias`属性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:08:27.724154Z",
     "start_time": "2020-04-15T07:08:27.716440Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w: Parameter containing:\n",
      "tensor([[-0.0129, -0.1627, -0.3308],\n",
      "        [ 0.1463,  0.3797, -0.5309]], requires_grad=True)\n",
      "b: Parameter containing:\n",
      "tensor([0.5564, 0.1864], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(10, 3)\n",
    "y = torch.randn(10, 2)\n",
    "\n",
    "linear = nn.Linear(3, 2)\n",
    "print(\"w:\", linear.weight)\n",
    "print(\"b:\", linear.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:12:32.469836Z",
     "start_time": "2020-04-15T07:12:32.453306Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 1.9151089191436768\n",
      "dL/dw tensor([[-0.2434, -0.1086, -0.7133],\n",
      "        [ 0.2724, -0.0422, -1.3558]])\n",
      "dL/db tensor([ 0.5101, -0.6411])\n"
     ]
    }
   ],
   "source": [
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)\n",
    "\n",
    "pred = linear(x)\n",
    "\n",
    "loss = criterion(pred, y)\n",
    "print(\"loss:\", loss.item())\n",
    "\n",
    "loss.backward()\n",
    "\n",
    "print(\"dL/dw\", linear.weight.grad)\n",
    "print(\"dL/db\", linear.bias.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:13:03.793843Z",
     "start_time": "2020-04-15T07:13:03.787716Z"
    }
   },
   "outputs": [],
   "source": [
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:13:19.972137Z",
     "start_time": "2020-04-15T07:13:19.966183Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dL/dw tensor([[-0.2434, -0.1086, -0.7133],\n",
      "        [ 0.2724, -0.0422, -1.3558]])\n",
      "dL/db tensor([ 0.5101, -0.6411])\n"
     ]
    }
   ],
   "source": [
    "print(\"dL/dw\", linear.weight.grad)\n",
    "print(\"dL/db\", linear.bias.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:14:33.395762Z",
     "start_time": "2020-04-15T07:14:33.390317Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss after 1 step optimization: 1.8837978839874268\n"
     ]
    }
   ],
   "source": [
    "pred = linear(x)\n",
    "loss = criterion(pred, y)\n",
    "print(\"loss after 1 step optimization:\", loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:15:42.211143Z",
     "start_time": "2020-04-15T07:15:42.201932Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.0105, -0.1616, -0.3237],\n",
       "        [ 0.1435,  0.3801, -0.5173]], requires_grad=True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:16:43.956659Z",
     "start_time": "2020-04-15T07:16:43.950843Z"
    }
   },
   "outputs": [],
   "source": [
    "x = np.array([[1, 2], [3, 4]])\n",
    "\n",
    "y = torch.from_numpy(x)\n",
    "z = y.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:16:48.257963Z",
     "start_time": "2020-04-15T07:16:48.252571Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 2],\n",
       "       [3, 4]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:16:53.581095Z",
     "start_time": "2020-04-15T07:16:53.575863Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "139860155684784"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:16:59.726101Z",
     "start_time": "2020-04-15T07:16:59.720679Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "139860155684656"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:17:06.079311Z",
     "start_time": "2020-04-15T07:17:06.070473Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "139860154068784"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据管道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:26:49.626982Z",
     "start_time": "2020-04-15T07:26:49.140313Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# 加载数据集\n",
    "\n",
    "train_dataset = torchvision.datasets.CIFAR10(\n",
    "    root='datasets',\n",
    "    train=True,\n",
    "    transform=transforms.ToTensor(),\n",
    "    download=True,  # 本地文件夹有文件时，不会重新下载\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:29:33.295202Z",
     "start_time": "2020-04-15T07:29:33.289140Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.Tensor'> torch.Size([3, 32, 32])\n",
      "6\n"
     ]
    }
   ],
   "source": [
    "# 返回张量数据\n",
    "\n",
    "image, label = train_dataset[0]\n",
    "\n",
    "print(type(image), image.shape)\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:30:52.583201Z",
     "start_time": "2020-04-15T07:30:52.571147Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 32, 32])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将数据转换成dataloader，并指定批次\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "dataiter = iter(train_loader)\n",
    "\n",
    "images, labels = next(dataiter)\n",
    "images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:35:52.963369Z",
     "start_time": "2020-04-15T07:35:52.957225Z"
    }
   },
   "outputs": [],
   "source": [
    "# 创建自己的数据集\n",
    "\n",
    "\n",
    "class CustomDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self):\n",
    "        # 指定数据所在的路径或文件名列表\n",
    "        pass\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # 1.从文件中读取数据，如numpy.fromfile, PIL.Image.open\n",
    "        # 2.预处理数据，如 torchvision.Transform\n",
    "        # 3.返回数据对，如 图像及标签\n",
    "        pass\n",
    "\n",
    "    def __len__(self):\n",
    "        # 数据量\n",
    "        return 100\n",
    "\n",
    "\n",
    "dataset = CustomDataset()\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:45:05.275714Z",
     "start_time": "2020-04-15T07:45:05.126671Z"
    }
   },
   "outputs": [],
   "source": [
    "# 加载预训练模型\n",
    "resnet = torchvision.models.resnet18(pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:45:17.317692Z",
     "start_time": "2020-04-15T07:45:16.523838Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "# 模型参数固定\n",
    "for param in resnet.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# 将输出层替换成自定义结构，如 10 分类的输出层    \n",
    "resnet.fc = nn.Linear(resnet.fc.in_features, 10)\n",
    "\n",
    "# 进行预测\n",
    "images = torch.randn(64, 3, 244, 244)\n",
    "outputs = resnet(images)\n",
    "print(outputs.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:47:55.508414Z",
     "start_time": "2020-04-15T07:47:55.482258Z"
    }
   },
   "outputs": [],
   "source": [
    "# 保存模型\n",
    "torch.save(resnet, \"models/torch.testsavemodel.ckpt\")\n",
    "\n",
    "# 加载模型\n",
    "model = torch.load(\"models/torch.testsavemodel.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 仅仅保存模型参数\n",
    "torch.save(resnet.state_dict(), 'params.ckpt')\n",
    "\n",
    "# 加载模型参数\n",
    "resnet.load_state_dict(torch.load('params.ckpt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:50:19.232822Z",
     "start_time": "2020-04-15T07:50:19.226071Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T07:56:58.487193Z",
     "start_time": "2020-04-15T07:56:58.472322Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Epoch [10/60], Loss:0.3196\n",
      " Epoch [20/60], Loss:0.2811\n",
      " Epoch [30/60], Loss:0.2744\n",
      " Epoch [40/60], Loss:0.2728\n",
      " Epoch [50/60], Loss:0.2721\n",
      " Epoch [60/60], Loss:0.2715\n"
     ]
    }
   ],
   "source": [
    "# 模型参数\n",
    "input_size = 1\n",
    "output_size = 1\n",
    "num_epochs = 60\n",
    "learning_rate = 0.001\n",
    "\n",
    "# 训练数据\n",
    "x_train = np.array(\n",
    "    [[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], [9.779], [6.182], [7.59],\n",
    "     [2.167], [7.042], [10.791], [5.313], [7.997], [3.1]],\n",
    "    dtype=np.float32)\n",
    "y_train = np.array(\n",
    "    [[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], [3.366], [2.596], [2.53],\n",
    "     [1.221], [2.827], [3.465], [1.65], [2.904], [1.3]],\n",
    "    dtype=np.float32)\n",
    "\n",
    "\n",
    "# 线性模型\n",
    "model = nn.Linear(input_size, output_size)\n",
    "\n",
    "# 损失函数\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# 优化器\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    inputs = torch.from_numpy(x_train)\n",
    "    targets = torch.from_numpy(y_train)\n",
    "\n",
    "    # 前向计算\n",
    "    outputs = model(inputs)\n",
    "\n",
    "    # 损失\n",
    "    loss = criterion(outputs, targets)\n",
    "\n",
    "    # 梯度归零\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # 反向传播，计算梯度\n",
    "    loss.backward()\n",
    "    \n",
    "    # 更新参数\n",
    "    optimizer.step()\n",
    "\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(\" Epoch [{}/{}], Loss:{:.4f}\".format(epoch + 1, num_epochs,\n",
    "                                                   loss.item()))\n",
    "        \n",
    "# 保存模型\n",
    "# torch.save(model.state_dict(), 'model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:01:17.909001Z",
     "start_time": "2020-04-15T08:01:17.753771Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU9b3/8deHEAmbUgErAmEQUWQNGFQKWpSlCLhcFcVSW722XJcq3ioUjYobCBev2w8qv7ihNdUiVlxwF1FERVllLRIJEEEElM2ABPK9f0wYMsOETJKZnFnez8cjj8n5zsmcj+Pwzjff8z3fY845REQk8dXyugAREYkOBbqISJJQoIuIJAkFuohIklCgi4gkidpeHbhJkybO5/N5dXgRkYS0YMGCrc65puGe8yzQfT4f8+fP9+rwIiIJyczWlfechlxERJKEAl1EJEko0EVEkoRnY+jhFBcXU1hYyN69e70uRYCMjAxatGhBenq616WISATiKtALCwtp2LAhPp8PM/O6nJTmnGPbtm0UFhbSunVrr8sRkQjE1ZDL3r17ady4scI8DpgZjRs31l9LIgkkrgIdUJjHEf2/EEkscRfoIiLJam/xAR56bzUbt++JyetHHOhmlmZmi8zsjTDP1TGzf5rZGjObZ2a+aBZZkwoLC7nwwgtp27Ytbdq0YcSIEezbty/svhs3buTSSy+t8DUHDhzI9u3bq1TP3XffzYMPPljhfg0aNDji89u3b+dvf/tblWoQkeqbNn8D7e58m8c++JqPV2+JyTEq00MfAaws57lrgB+dcycBDwMTqltYRPLywOeDWrX8j3l51Xo55xwXX3wxF110EV9//TWrV69m9+7d5OTkHLbv/v37OeGEE5g+fXqFr/vmm2/SqFGjatVWXQp0EW/s2FOMb/RMRk3/CoCLsk5g6OmZMTlWRIFuZi2AQcCT5exyIfBs6ffTgT4W6wHYvDwYPhzWrQPn/I/Dh1cr1GfNmkVGRgZXX301AGlpaTz88MM8/fTTFBUVMXXqVIYMGcL5559P//79KSgooGPHjgAUFRVx2WWX0blzZy6//HLOOOOMwNIGPp+PrVu3UlBQwKmnnsqf/vQnOnToQP/+/dmzx/+n1xNPPEH37t3p0qULl1xyCUVFRUesde3atfTo0YPu3btz5513Btp3795Nnz596NatG506deLVV18FYPTo0eTn55OVlcXIkSPL3U9EomfKR/l0uefdwPbHI8/hkaFdY3dA51yFX/hD+jSgN/BGmOeXAS3KbOcDTcLsNxyYD8zPzMx0oVasWHFYW7latXLOH+XBX61aRf4aIR599FF38803H9aelZXllixZ4p555hnXvHlzt23bNuecc2vXrnUdOnRwzjk3ceJEN3z4cOecc0uXLnVpaWnuyy+/LC21lduyZYtbu3atS0tLc4sWLXLOOTdkyBD397//3Tnn3NatWwPHy8nJcY899phzzrkxY8a4iRMnHlbT+eef75599lnnnHOTJk1y9evXd845V1xc7Hbs2OGcc27Lli2uTZs2rqSkJKjWI+0XqlL/T0TEOefc5h17XKu/vhH4Gjczev+OgPmunKyusIduZoOB751zC460W7jfFWF+eeQ657Kdc9lNm4ZdLCxy69dXrj0CzrmwMzvKtvfr149jjz32sH0++eQThg4dCkDHjh3p3Llz2GO0bt2arKwsAE477TQKCgoAWLZsGWeddRadOnUiLy+P5cuXH7HWuXPncsUVVwBw5ZVXBtV6++2307lzZ/r27cu3337L5s2bw/43RbKfiFTOfW+s4PRxHwS2v8zpy20DT62RY0dyYVFP4AIzGwhkAEeb2fPOud+V2acQaAkUmllt4Bjgh6hXW1Zmpn+YJVx7FXXo0IGXX345qG3nzp1s2LCBNm3asGDBAurXrx/2Z12EN9uuU6dO4Pu0tLTAkMtVV13FjBkz
6NKlC1OnTmX27NkVvla4Xz55eXls2bKFBQsWkJ6ejs/nCzuXPNL9RCQyBVt/oveDswPbOQNP5U9nn1ijNVTYQ3fO3eaca+Gc8wFDgVkhYQ7wGvCH0u8vLd0nsoSrqrFjoV694LZ69fztVdSnTx+Kiop47rnnADhw4AC33HILV111FfVCjxWiV69eTJs2DYAVK1awdOnSSh17165dNGvWjOLiYvIiOA/Qs2dPXnzxRYCg/Xfs2MFxxx1Heno6H374IetKf+k1bNiQXbt2VbifiFTejS8sCgrzr+7uX+NhDtWYh25m95rZBaWbTwGNzWwN8BdgdDSKO6JhwyA3F1q1AjP/Y26uv72KzIxXXnmFl156ibZt23LyySeTkZHBuHHjKvzZ66+/ni1bttC5c2cmTJhA586dOeaYYyI+9n333ccZZ5xBv379aNeuXYX7P/roo0yePJnu3buzY8eOQPuwYcOYP38+2dnZ5OXlBV6rcePG9OzZk44dOzJy5Mhy9xORyC37dge+0TN5fclGAB4c0oWC8YM4OsOb9Y8s1h3p8mRnZ7vQG1ysXLmSU0+tmbGmaDtw4ADFxcVkZGSQn59Pnz59WL16NUcddZTXpVVLIv8/EYmVkhLH0NzP+aLAP7L8i3rpfHZbHzLS02J+bDNb4JzLDvdcXC3OlciKioo455xzKC4uxjnH448/nvBhLiKH+zR/K799Yl5g++mrsjm33S89rOgQBXqUNGzYULfUE0lixQdK6PvQR6zb5r9GpN3xDZl501mk1YqfNY8U6CIiFXh72SaufX5hYHv6tT3I9h0+fdlrCnQRkXLs2XeArve9y97iEgDOPrkpz17dPW5XIlWgi4iE8Y9567n9lUPTj9+5+WxOOb6hhxVVTIEuIlLG9qJ9ZN37XmB7yGktmDiki4cVRU7roYdIS0sjKysr8FVQUMD8+fO56aabAJg9ezaffvppYP8ZM2awYsWKSh+nvOVuD7ZHujSviETPpFlfB4X5nFHnJEyYg3roh6lbty6LFy8OavP5fGRn+6d9zp49mwYNGvCrX/0K8Af64MGDad++fVTriHRpXhGpvu927OXMBw6tv3LDOW0Y+ZvEu9hOPfQIzJ49m8GDB1NQUMCUKVN4+OGHycrK4qOPPuK1115j5MiRZGVlkZ+fT35+PgMGDOC0007jrLPOYtWqVUD5y92Wp+zSvFOnTuXiiy9mwIABtG3bllGjRgX2e/fdd+nRowfdunVjyJAh7N69OzZvgkiSGvPqsqAwX3BH34QMc4jjHvo9ry9nxcadUX3N9icczZjzOxxxnz179gRWQ2zdujWvvPJK4Dmfz8e1115LgwYNuPXWWwG44IILGDx4cGB4pE+fPkyZMoW2bdsyb948rr/+embNmsWIESO47rrr+P3vf8/kyZMrXfvixYtZtGgRderU4ZRTTuHGG2+kbt263H///bz//vvUr1+fCRMm8NBDD3HXXXdV+vVFUk3+lt30+d+PAtt3DW7Pf/Zq7WFF1Re3ge6VcEMukdq9ezeffvopQ4YMCbT9/PPPgH+524MrOV555ZX89a9/rdRr9+nTJ7A2TPv27Vm3bh3bt29nxYoV9OzZE4B9+/bRo0ePKtUukiqcc1z3/ELeXv5doG3ZPb+hQZ3Ej8O4/S+oqCcdj0pKSmjUqFG5vxCqM3c1dNnd/fv345yjX79+vPDCC1V+XZFU8lXhdi6YNDew/ejQLC7Mau5hRdGlMfRKCl2Gtuz20UcfTevWrXnppZcAf09gyZIlQPnL3VbHmWeeydy5c1mzZg3gX09m9erVUXltkWRSUuK4aPLcQJgf17AO/75/QFKFOSjQK+3888/nlVdeISsrizlz5jB06FAmTpxI165dyc/PJy8vj6eeeoouXbrQoUOHwL06y1vutjqaNm3K1KlTueKKK+jcuTNnnnlm4CSsiPj9Y956Trz9TRZv2A7A1Ku780VOX+rUjv3KiDVNy+fKEen/iSSqon37aX/X
O4HtTs2PYcYNPeNqMa2q0PK5IpJSrs9bwJtLD530vPv89lzVM7FnsERCgS4iSWPr7p/Jvv/9oLa1DwyM28W0oi3uAt05lzJvfrzzajhOpCoGPPIxq747NGHh8WHdOK9TMw8rqnlxFegZGRls27aNxo0bK9Q95pxj27ZtZGRkeF2KyBF9s2U355a5QAigYPwgj6rxVlwFeosWLSgsLGTLli1elyL4f8G2aNHC6zJEyuUbPTNo++XrenBaq/i78URNqTDQzSwD+BioU7r/dOfcmJB9rgImAt+WNk1yzj1Z2WLS09Np3Tr5T1yISPUsWPcDlzz+WVBbqvbKy4qkh/4zcK5zbreZpQOfmNlbzrnPQ/b7p3Puz9EvUUTkkNBe+Qe3/Jo2TcMvR51qKrywyPkdXMIvvfRLZ8tEpEa9vWxTUJi3Pa4BBeMHJVaY5+WBzwe1avkfo3TV+EERjaGbWRqwADgJmOycmxdmt0vM7GxgNfDfzrkNYV5nODAcIDMzs8pFi0jqcM7R+rY3g9q+zOlL04Z1yvmJOJWXB8OHQ1GRf3vdOv82wLBhUTlEpa4UNbNGwCvAjc65ZWXaGwO7nXM/m9m1wGXOuXOP9FrhrhQVESnrmblruef1Q3cEO6/j8Tz+u9M8rKgafD5/iIdq1QoKCiJ+mahdKeqc225ms4EBwLIy7dvK7PYEMKEyrysiUlbxgRLa5rwV1Lbi3t9Q76i4mphXOevXV669CiocQzezpqU9c8ysLtAXWBWyT9nZ+xcAK6NWoYiklHtfXxEU5tf+ug0F4wcldpgDlDfMHMXh50jeoWbAs6Xj6LWAac65N8zsXmC+c+414CYzuwDYD/wAXBW1CkUkJez+eT8dx7wT1LZm7HnUTkuSRWHHjg0eQweoV8/fHiVxtdqiiKSma6Z+yQervg9s33dRR648s1X1XzgvD3Jy/MMamZn+8IzSCUiv6tFqiyISl77fuZfTx30Q1Ba1xbRqYFZJpQ0bFtNjq4cuIp749cQPWbft0PDDk7/Ppm/7X0bvAFGaVRJv1EMXkbjx9eZd9Hv446C2mFy2XwOzSuKNAl1EakzoZfszbuhJVstGsTlYZmb4HnoSX9SYJKePRSSeff7NtqAwr1O7FgXjB8UuzMF/wrFeveC2KM8qiTfqoYtITIX2yj8a2ZtWjevH/sAHTz7G0yyXGFOgi0hMvL5kIze+sCiw3an5Mbx+Y6+aLSLGs0rijQJdRKIq3GJaC+/sx7H1j/KootShQBeRqPn/H+XzwFuHVga5KOsEHhna1cOKUosCXUSqbd/+Ek6+I3gxrVX3DSAjPc2jilKTZrmIVEaMb1CQiO6YsTQozG/q05aC8YMU5h5QD10kUvF4KbmHdu4tpvPd7wa15Y8bSFqtKFy2L1WiS/9FIpWkl5JXxe+enMcna7YGtidc0onLuyfvBTvxRJf+i0RDCl5KHmrTjj30eGBWUFtMLtuXKlGgi0QqBS8lL+uMce+zeefPge2pV3en9ynHeViRhNJJUZFIpeCl5AArN+3EN3pmUJgXjB+kMI9D6qGLRCoFLyUPvWz/jRt70bH5MR5VIxVRoItURopcSj53zVaGPTkvsH1M3XSWjOnvYUUSCQW6iAQJ7ZXPGXUOLY+tV87eEk8U6CICwL8WFvKXaUsC2919v+Cla3/lYUVSWQp0kRRXUuI48fbgxbSW3NWfY+qle1SRVFWFs1zMLMPMvjCzJWa23MzuCbNPHTP7p5mtMbN5ZuaLRbEiEl2TZn0dFOaXZbegYPwghXmCiqSH/jNwrnNut5mlA5+Y2VvOuc/L7HMN8KNz7iQzGwpMAC6PQb0iEgV7iw/Q7s63g9q0mFbiqzDQnX9tgN2lm+mlX6HrBVwI3F36/XRgkpmZ82pdAREp16jpS5g2vzCwfWv/k/nzuW09rEiiJaIxdDNLAxYAJwGTnXPzQnZpDmwAcM7tN7MdQGNga8jrDAeGA2SmyNV1IvFie9E+su59L6jtm3ED
qaXFtJJGRIHunDsAZJlZI+AVM+vonFtWZpdwn4jDeufOuVwgF/yLc1WhXhGpgtCpiA9f3oX/6NrCo2okVio1y8U5t93MZgMDgLKBXgi0BArNrDZwDPBDtIoUkapZsXEnAx+bE9SmxbSSV4WBbmZNgeLSMK8L9MV/0rOs14A/AJ8BlwKzNH4u4q3QXvn4t/8fQ7evgk7bU+Jq11QUSQ+9GfBs6Th6LWCac+4NM7sXmO+cew14Cvi7ma3B3zMfGrOKReSIZq3azH9ODb7XQMGEwYc2UvimHMlON7gQSSKhvfLnP3yMXl+8e/iOKXhTjmShG1yIJLmpc9dy9+srgtoKxg+CWueH/4EUuilHKtF66CKxUgM3lHbO4Rs9MyjM3/vvsw+d+CxverCmDSclBbpILBy8ofS6deDcoRtKRzHU75yxjNa3Ba/BUjB+EG1/2fBQQ4relCNVaQxdJBZieEPp/QdKOCnnraC2+Xf0pUmDOuF/IC8vpW7KkeyONIauQBeJhVq1/D3zUGZQUlLll71o8lwWb9ge2G7eqC5zR59b5deTxHOkQNeQS6qogfFcKSPKY9fbi/bhGz0zKMxX3TdAYS5BNMslFRwczy0q8m8fHM8F/ekdK2PHBr/nUOWx69CpiKc2O5q3RpxV3QolCamHngpycoKDBfzbOTne1JMKhg2D3Fz/mLmZ/zE3t1K/QNd8v/uwMP9m3ECFuZRLY+ipIEbjuRI7oUE+oMPxTLnyNI+qkXiiC4tSXWZm+BkXmoscdz5evYXfP/1FUJsW05JIKdBTQRTHcyV2QnvluvGEVJYCPRUcHLfVXOS49OynBYx5bXlQm3rlUhUK9FQxbJgCPA6F9sqn/K4bAzo286gaSXQKdBEP3Pavr3jhiw1BbeqVS3Up0EVqkHPusPVX3rixFx2bH+NRRZJMNA9dkl+cXCU74JGPwy6mpTCXaFEPXZJbHFwl+/P+A5xyx9tBbV/c3ofjjs6okeNL6tCFRZLcYrjqYUSHDznpCRorl+rRhUWSusq7M0+M79izdffPZN//flDbqvsGkJGeFtPjSmrTGLokNw/u2OMbPTMozFs3qU/B+EHVD/M4ORcg8avCQDezlmb2oZmtNLPlZjYizD69zWyHmS0u/borNuWKVFIN3rFn4fofDxtiWfvAQD68tXf1X7wG7oAkiS+SIZf9wC3OuYVm1hBYYGbvOedWhOw3xzk3OPolilRDDV0lGxrkF2adwKNDu0bvAEdaMVMXjEmpCgPdObcJ2FT6/S4zWwk0B0IDXSQ+xfAq2Zfmb2Dk9K+C2mJy0tOjcwGSWCp1UtTMfEBXYF6Yp3uY2RJgI3Crc255mH1EkkZor/yaXq25c3D72BxMK2ZKBCIOdDNrALwM3Oyc2xny9EKglXNut5kNBGYAhy0TZ2bDgeEAmfogSoIa8+oynv0sOFxjPhVRK2ZKBCKah25m6cAbwDvOuYci2L8AyHbObS1vH81Dl0QU2it/6LIuXNytRc0cPC9PK2ZK9eahm5kBTwErywtzMzse2Oycc2Z2Ov7ZM9uqUbNIXBn46BxWbAr+w7TGLxDSiplSgUiGXHoCVwJLzWxxadvtQCaAc24KcClwnZntB/YAQ51Xl6CKRFFJiePE24PXX5lxQ0+yWjbyqCKR8kUyy+UTwCrYZxIwKVpFicQDXbYviUaX/ouE+Onn/XQY805Q27zb+/BLLaYlcU6BLlKGeuWSyBToIsCGH4o4638+DGrTYlqSaBTokvLUK5dkoUCXlPVZ/jaueOLzoLa1DwzEP1NXJPEo0CUlhfbKf9WmMf/405keVSMSHQp0SSnPfVbAXa8GLzOk4RVJFgp0SRmhvfIbzz2JW/qf4lE1ItGnQJek98j7q3nk/a+D2tQrl2SkQJekFtorn/zbbgzq3MyjakRiS4EuSemPz87n/ZWbg9rUK5dkp0CXpHKgxNEmZDGtWbf8mhObNvCoIpGao0CXpNH13nf5sag4qE29ckklCnRJ
eLt/3k/HkMW0ltzVn2PqpXtUkYg3FOiS0HTZvsghCnRJSIU/FtFrQvBiWl+PPY/0tFoeVSTiPX36xXt5eeDzQa1a/se8vCPu7hs9MyjMT/cdS8H4QQpzSXnqoYu38vKC72a/bp1/Gw67f+aCdT9wyeOfBbVpeEXkEPPq1p/Z2dlu/vz5nhxb4ojP5w/xUK1aQUHBod1Cxsr/2Ks1dwxuH9vaROKQmS1wzmWHe049dPHW+vVHbP/XwkL+Mm1J0FPqlYuEp0AXb2Vmhu+hZ2Ye1iv/n0s7c1l2yxoqTCTxVHgWycxamtmHZrbSzJab2Ygw+5iZPWZma8zsKzPrFptyJemMHQv16gU1PdD3T/iGTg5qKxg/SGEuUoFIeuj7gVuccwvNrCGwwMzec86tKLPPeUDb0q8zgMdLH0WO7OCJz5wcWL8e36jXg56e9l89OL31sR4UJpJ4Kgx059wmYFPp97vMbCXQHCgb6BcCzzn/GdbPzayRmTUr/VmRIxs2jN8WteHT/G1BzRorF6mcSo2hm5kP6ArMC3mqObChzHZhaVtQoJvZcGA4QGZmZuUqlaS0/0AJJ+W8FdQ2Z9Q5tDy2Xjk/ISLliTjQzawB8DJws3NuZ+jTYX7ksPmQzrlcIBf80xYrUackobY5b1J8IPhjoF65SNVFFOhmlo4/zPOcc/8Ks0shUPaMVQtgY/XLk2S0Y08xXe55N6ht6d39aZihxbREqqPCQDczA54CVjrnHipnt9eAP5vZi/hPhu7Q+LmEEzoVsUGd2iy75zceVSOSXCLpofcErgSWmtni0rbbgUwA59wU4E1gILAGKAKujn6pksi+27GXMx/4IKgtf9xA0mqFG60TkaqIZJbLJ4QfIy+7jwNuiFZRklxCe+W9T2nK1KtP96gakeSlK0UlZpZv3MGgxz4JatNJT5HYUaBLTIT2yidc0onLu2uqqkgsKdAlqj5YuZlrng1eRVO9cpGaoUCXqAntlef98Qx6ntTEo2pEUo8CXartmblruef1FUFt6pWL1DwFulSZc47Wt70Z1Pb+X87mpOMaelSRSGpToEuV3DFjKc9/HnxzCvXKRbylQJdKCbeY1vw7+tKkQR2PKhKRgxToErFLHv+UBet+DGy3PLYuc0ad62FFIlJWhXcskkrIy/Pf9LhWLf9jXp7XFUXFrr3F+EbPDArzVfcNUJiHk6SfAUkM6qFHS14eDB8ORUX+7XXr/Ntw6K48CSh0idvzOh7P4787zcOK4liSfgYkcZh/GZaal52d7ebPn1/xjonC5wt/s+NWraCgoKarqbbCH4voNeHDoLZvxg2klhbTKl+SfQYkPpnZAudcdrjnNOQSLevXV649jvlGzwwK85v6tKVg/KDqhXkqDEUk0WdAEpOGXKIlMzN87yyBbrW3ZMN2Lpw8N6gtKlMRU2UoIgk+A5LY1EOPlrFjoV7IfTDr1fO3JwDf6JlBYf7I5VnRm1eek3MozA8qKvK3J5ME/wxI4lOgR8uwYZCb6x8vNfM/5ubGfQ/07WWbDluDpWD8IC7q2jx6B0mVoYgE/QxI8tBJ0WSRl+fv8a5f7/8Tf+zYCoMkNMin/VcPTm99bPRr08lCkajRSdFkd3CMet06cO7QGHU5Jx6nfJQftlcekzAHDUWI1BD10JNBhD3gcItpfXhrb1o3qR/b+qBKf0GIyOGO1ENXoCeDWrX8PfNQZlBSAsAt05bw8sLCoKe1mJZI4jlSoGvaYjI4wnS5fftLOPmO4MW0Ft/Vj0b1jqqh4kSkplQ4hm5mT5vZ92a2rJzne5vZDjNbXPp1V/TLlCMqZ4z6vKseCwrzdsc3pGD8IIW5SJKKpIc+FZgEPHeEfeY45wZHpSKpvINj0aVj1DvanEKXSx6EvYd2+ff9A6hTO82b+kSkRlQY6M65j83MF/tSpFqGDYNhww6bvfIfXZvz8OVZHhUlIjUpWmPoPcxsCbARuNU5tzzcTmY2HBgOkKnLoaPq+117OX3sB0Ft
ax8YiJkW0xJJFdEI9IVAK+fcbjMbCMwA2obb0TmXC+SCf5ZLFI4tQJ//nU3+lp8C26MGnML1vU/ysCIR8UK1A905t7PM92+a2d/MrIlzbmt1X1uObM33u+n70EdBbZqKKJK6qh3oZnY8sNk558zsdPwzZ7ZVuzI5otCx8pev+xWntfqFR9WISDyoMNDN7AWgN9DEzAqBMUA6gHNuCnApcJ2Z7Qf2AEOdV1crpYAvC35gyJTPAttmsPYB9cpFJLJZLldU8Pwk/NMaJcZCe+U1dtm+iCQEXSmaAGZ+tYkb/rEwsN3u+Ia8ffPZHlYkIvFIgR7Hwi2mNf+OvjRpUMejikQkninQ49STc77h/pkrA9uDOjVj8rBuHlYkIvFO66FXVoxvdlx8oATf6JlBYb7i3t8ozEWkQuqhV0aMb3Z892vLmfppQWD7+t5tGDWgXbVfV0RSg9ZDr4wY3Upt195iOt39blBb/riBpNXSZfsiEkzroUdLDG52/Ienv+Cj1VsC2+P+oxO/PUPr3IhI5SXWGHqMx68rVN6CYlVYaOy7HXvxjZ4ZFOZrHxioMBeRKkucHnqMx68jMnZscA1QpZsd95owi8If9wS2n/pDNn1O/WW0qhSRFJU4PfScnOAgBf92Tk7N1TBsGOTm+sfMzfyPubkR/0JZvXkXvtEzg8K8YPwghbmIREXinBSN4EbI8Sz0sv1Xb+hJl5aNPKpGRBLVkU6KJk4PPYrj1zXp0/ytQWFe/6g0CsYPUpiLSNQlzhh6lMava1Jor/zjkeeQ2bheOXuLiFRP4vTQqzl+XZNeXfxtUJh3admIgvGDFOYiElOJ00OHwI2Q41W4xbQW3dmPX9Q/yqOKRCSVJE4PPc69uvjboDC/uGtzCsYPUpiLSI1JrB56HCo+UELbnLeC2v59/wDq1E7zqCIRSVUK9GrI/TifcW+uCmxPvLQzQ7JbeliRiKQyBXoV/PTzfjqMeSeo7ZtxA6mlxbRExEMK9EqavqCQW19aEth+5urunHPKcR5WJCLiV2Ggm9nTwGDge+dcxzDPG/AoMBAoAq5yzi0M3S/R7dxbTOcyS9zWTU9j5X0DPKxIRCRYJD30qcAk4Llynj8PaFv6dQbweOlj0ggdK599a298TYoAtFcAAAXdSURBVOp7WJGIyOEqDHTn3Mdm5jvCLhcCzzn/ojCfm1kjM2vmnNsUpRo98/2uvZw+9oPA9jW9WnPn4PYeViQiUr5ojKE3BzaU2S4sbTss0M1sODAcIDPO12AZO3MFT8xZG9j+4vY+HHd0hocViYgcWTQCPdzUjrBLODrncoFc8K+2GIVjR926bT/x64mzA9t/HdCO63q38a4gEZEIRSPQC4Gyk69bABuj8Lo1bsSLi3h18aHSl4zpzzF10z2sSEQkctEI9NeAP5vZi/hPhu5ItPHz5Rt3MOixTwLb/3NpZy7TBUIikmAimbb4AtAbaGJmhcAYIB3AOTcFeBP/lMU1+KctXh2rYqPNOcfQ3M+Zt/YHABpm1ObLnL5kpOuyfRFJPJHMcrmigucdcEPUKqohn3+zjaG5nwe2n/h9Nv3a61ZwIpK4Uu5K0f0HSuj38Mes3foTACcd14C3R5xF7TQtPCkiiS2lAv3tZd9x7fMLAtvT/qsHp7c+1sOKRESiJyUCfW/xAbrd9x5F+w4A0POkxjx/zRn4Vy0QEUkOSR/o//xyPX99eWlg+60RZ3Fqs6M9rEhEJDaSNtB3FBXT5d5Di2ld3K05D12W5WFFIiKxlZSBPvnDNUx859+B7TmjzqHlsbpBs4gkt6QK9M0793LGuEOLaV376zaMPq+dhxWJiNScpAn0u19bztRPCwLbX+b0pWnDOt4VJCJSwxI+0Ndu/YlzHpwd2L5j0Kn88awTvStIRMQjCRvozjn+/I9FzFx6aNmYpXf3p2GGFtMSkdSUkIG+tHAH5086tJjWQ5d14eJuLTysSETEewkX6Bt+KAqEeeP6RzF39LlaTEtEhAQM
9AZ1atPzpMZc06s157bTYloiIgclXKD/ov5R5P3xTK/LEBGJO1piUEQkSSjQRUSShAJdRCRJKNBFRJKEAl1EJEko0EVEkoQCXUQkSSjQRUSShDnnvDmw2RZgXQS7NgG2xricRKT3pXx6b8LT+1K+RHpvWjnnmoZ7wrNAj5SZzXfOZXtdR7zR+1I+vTfh6X0pX7K8NxpyERFJEgp0EZEkkQiBnut1AXFK70v59N6Ep/elfEnx3sT9GLqIiEQmEXroIiISAQW6iEiSiMtAN7OWZvahma00s+VmNsLrmuKJmaWZ2SIze8PrWuKJmTUys+lmtqr0s9PD65rihZn9d+m/pWVm9oKZZXhdk1fM7Gkz+97MlpVpO9bM3jOzr0sff+FljVUVl4EO7Aducc6dCpwJ3GBm7T2uKZ6MAFZ6XUQcehR42znXDuiC3iMAzKw5cBOQ7ZzrCKQBQ72tylNTgQEhbaOBD5xzbYEPSrcTTlwGunNuk3NuYen3u/D/w2zubVXxwcxaAIOAJ72uJZ6Y2dHA2cBTAM65fc657d5WFVdqA3XNrDZQD9jocT2ecc59DPwQ0nwh8Gzp988CF9VoUVESl4Felpn5gK7APG8riRuPAKOAEq8LiTMnAluAZ0qHo540s/peFxUPnHPfAg8C64FNwA7n3LveVhV3fumc2wT+DiVwnMf1VElcB7qZNQBeBm52zu30uh6vmdlg4Hvn3AKva4lDtYFuwOPOua7ATyTon83RVjoefCHQGjgBqG9mv/O2KomFuA10M0vHH+Z5zrl/eV1PnOgJXGBmBcCLwLlm9ry3JcWNQqDQOXfwL7np+ANeoC+w1jm3xTlXDPwL+JXHNcWbzWbWDKD08XuP66mSuAx0MzP8Y6ErnXMPeV1PvHDO3eaca+Gc8+E/qTXLOaeeFuCc+w7YYGanlDb1AVZ4WFI8WQ+caWb1Sv9t9UEnjEO9Bvyh9Ps/AK96WEuV1fa6gHL0BK4ElprZ4tK2251zb3pYk8S/G4E8MzsK+Aa42uN64oJzbp6ZTQcW4p9BtogkudS9KszsBaA30MTMCoExwHhgmpldg/8X4BDvKqw6XfovIpIk4nLIRUREKk+BLiKSJBToIiJJQoEuIpIkFOgiIklCgS4ikiQU6CIiSeL/AAtf/ILAww2pAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Evaluate the model on the training inputs without tracking gradients,\n",
    "# then convert the result to a NumPy array so matplotlib can draw it.\n",
    "with torch.no_grad():\n",
    "    predicted = model(torch.from_numpy(x_train)).numpy()\n",
    "\n",
    "# Red circles are the observed points; the line is the model's prediction.\n",
    "plt.plot(x_train, y_train, 'ro', label='Original data')\n",
    "plt.plot(x_train, predicted, label='Fitted line')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 前馈神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:18:29.146531Z",
     "start_time": "2020-04-15T08:18:29.141071Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 硬件设置\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:19:43.491932Z",
     "start_time": "2020-04-15T08:19:43.486614Z"
    }
   },
   "outputs": [],
   "source": [
    "# 参数\n",
    "input_size = 28 * 28\n",
    "hidden_size = 500\n",
    "num_classes = 10\n",
    "num_epochs = 5\n",
    "batch_size = 100\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:05:39.114799Z",
     "start_time": "2020-04-15T08:05:39.082709Z"
    }
   },
   "outputs": [],
   "source": [
    "# Datasets: MNIST training split (download=True fetches it into 'datasets')\n",
    "# and MNIST test split. ToTensor converts each image to a tensor when it is loaded.\n",
    "train_dataset = torchvision.datasets.MNIST('datasets',\n",
    "                                           train=True,\n",
    "                                           download=True,\n",
    "                                           transform=transforms.ToTensor())\n",
    "test_dataset = torchvision.datasets.MNIST('datasets',\n",
    "                                          train=False,\n",
    "                                          transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:07:28.353226Z",
     "start_time": "2020-04-15T08:07:28.347812Z"
    }
   },
   "outputs": [],
   "source": [
    "# Data pipeline: both splits are served in batches of batch_size;\n",
    "# only the training split is shuffled.\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset,\n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:19:47.568302Z",
     "start_time": "2020-04-15T08:19:46.461953Z"
    }
   },
   "outputs": [],
   "source": [
    "# 创建模型\n",
    "class NetWork(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        super(NetWork, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Instantiate the network, then move its parameters to the selected device.\n",
    "model = NetWork(input_size, hidden_size, num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:27:26.223677Z",
     "start_time": "2020-04-15T08:27:10.454379Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Step[100/600], Loss:2.3084\n",
      "Epoch [1/5], Step[200/600], Loss:2.3019\n",
      "Epoch [1/5], Step[300/600], Loss:2.3102\n",
      "Epoch [1/5], Step[400/600], Loss:2.3187\n",
      "Epoch [1/5], Step[500/600], Loss:2.3130\n",
      "Epoch [1/5], Step[600/600], Loss:2.2980\n",
      "Epoch [2/5], Step[100/600], Loss:2.2927\n",
      "Epoch [2/5], Step[200/600], Loss:2.3061\n",
      "Epoch [2/5], Step[300/600], Loss:2.3029\n",
      "Epoch [2/5], Step[400/600], Loss:2.3056\n",
      "Epoch [2/5], Step[500/600], Loss:2.3078\n",
      "Epoch [2/5], Step[600/600], Loss:2.3146\n",
      "Epoch [3/5], Step[100/600], Loss:2.3001\n",
      "Epoch [3/5], Step[200/600], Loss:2.3082\n",
      "Epoch [3/5], Step[300/600], Loss:2.3022\n",
      "Epoch [3/5], Step[400/600], Loss:2.3051\n",
      "Epoch [3/5], Step[500/600], Loss:2.3030\n",
      "Epoch [3/5], Step[600/600], Loss:2.3102\n",
      "Epoch [4/5], Step[100/600], Loss:2.3081\n",
      "Epoch [4/5], Step[200/600], Loss:2.3049\n",
      "Epoch [4/5], Step[300/600], Loss:2.3087\n",
      "Epoch [4/5], Step[400/600], Loss:2.3134\n",
      "Epoch [4/5], Step[500/600], Loss:2.3087\n",
      "Epoch [4/5], Step[600/600], Loss:2.3062\n",
      "Epoch [5/5], Step[100/600], Loss:2.2983\n",
      "Epoch [5/5], Step[200/600], Loss:2.3066\n",
      "Epoch [5/5], Step[300/600], Loss:2.3017\n",
      "Epoch [5/5], Step[400/600], Loss:2.3075\n",
      "Epoch [5/5], Step[500/600], Loss:2.3174\n",
      "Epoch [5/5], Step[600/600], Loss:2.2938\n"
     ]
    }
   ],
   "source": [
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Train the model\n",
    "total_step = len(train_loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Flatten every image in the batch to a vector of 28 * 28 values.\n",
    "        images = images.reshape(-1, 28 * 28).to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass and loss for this batch\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Clear the gradients left over from the previous step\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Backward pass\n",
    "        loss.backward()\n",
    "\n",
    "        # Update the parameters. zero_grad() and step() must be called on the\n",
    "        # same optimizer object, the one built from model.parameters() above;\n",
    "        # stepping any other optimizer leaves this model's weights unchanged.\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the current batch loss every 100 steps\n",
    "        if (i + 1) % 100 == 0:\n",
    "            print(\"Epoch [{}/{}], Step[{}/{}], Loss:{:.4f}\".format(\n",
    "                epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:32:42.035325Z",
     "start_time": "2020-04-15T08:32:41.431480Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate the model on the test set\n",
    "# Switch to evaluation mode, as the other evaluation cells in this notebook do;\n",
    "# gradients are not needed here, so tracking is disabled.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, 28 * 28).to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # The predicted class is the index of the largest score in each row.\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print(\"Accuracy of the network on the 10000 test images: {}%\".format(\n",
    "        100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 卷积神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:46:32.791127Z",
     "start_time": "2020-04-15T08:46:13.445036Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Step [100/600], Loss: 0.1279\n",
      "Epoch [1/5], Step [200/600], Loss: 0.1408\n",
      "Epoch [1/5], Step [300/600], Loss: 0.0327\n",
      "Epoch [1/5], Step [400/600], Loss: 0.0621\n",
      "Epoch [1/5], Step [500/600], Loss: 0.0393\n",
      "Epoch [1/5], Step [600/600], Loss: 0.0783\n",
      "Epoch [2/5], Step [100/600], Loss: 0.0343\n",
      "Epoch [2/5], Step [200/600], Loss: 0.0319\n",
      "Epoch [2/5], Step [300/600], Loss: 0.1094\n",
      "Epoch [2/5], Step [400/600], Loss: 0.0315\n",
      "Epoch [2/5], Step [500/600], Loss: 0.0819\n",
      "Epoch [2/5], Step [600/600], Loss: 0.0238\n",
      "Epoch [3/5], Step [100/600], Loss: 0.0184\n",
      "Epoch [3/5], Step [200/600], Loss: 0.0192\n",
      "Epoch [3/5], Step [300/600], Loss: 0.0844\n",
      "Epoch [3/5], Step [400/600], Loss: 0.0182\n",
      "Epoch [3/5], Step [500/600], Loss: 0.0296\n",
      "Epoch [3/5], Step [600/600], Loss: 0.0088\n",
      "Epoch [4/5], Step [100/600], Loss: 0.0256\n",
      "Epoch [4/5], Step [200/600], Loss: 0.0142\n",
      "Epoch [4/5], Step [300/600], Loss: 0.0064\n",
      "Epoch [4/5], Step [400/600], Loss: 0.0159\n",
      "Epoch [4/5], Step [500/600], Loss: 0.0355\n",
      "Epoch [4/5], Step [600/600], Loss: 0.0154\n",
      "Epoch [5/5], Step [100/600], Loss: 0.0040\n",
      "Epoch [5/5], Step [200/600], Loss: 0.0416\n",
      "Epoch [5/5], Step [300/600], Loss: 0.0152\n",
      "Epoch [5/5], Step [400/600], Loss: 0.1083\n",
      "Epoch [5/5], Step [500/600], Loss: 0.0531\n",
      "Epoch [5/5], Step [600/600], Loss: 0.0345\n"
     ]
    }
   ],
   "source": [
    "# 创建CNN模型\n",
    "\n",
    "\n",
    "class ConvNet(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "        )\n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "        )\n",
    "        self.fc = nn.Linear(7 * 7 * 32, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        out = out.reshape(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "model = ConvNet(num_classes).to(device)\n",
    "\n",
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Training loop\n",
    "total_step = len(train_loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        # Forward pass: class scores, then the loss for this batch\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass: reset stale gradients, backpropagate, apply the update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Only every 100th step is logged\n",
    "        if (i + 1) % 100 != 0:\n",
    "            continue\n",
    "        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(\n",
    "            epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:47:28.407496Z",
     "start_time": "2020-04-15T08:47:27.886517Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy of the model on the 10000 test images: 98.87 %\n"
     ]
    }
   ],
   "source": [
    "# Model evaluation\n",
    "# eval() makes batchnorm use its moving mean/variance instead of mini-batch statistics.\n",
    "\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # The predicted class is the index of the largest score in each row.\n",
    "        _, predicted = outputs.data.max(1)\n",
    "        total += labels.size(0)\n",
    "        correct += predicted.eq(labels).sum().item()\n",
    "\n",
    "print('Test Accuracy of the model on the 10000 test images: {} %'.format(\n",
    "    100 * correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T08:54:21.795281Z",
     "start_time": "2020-04-15T08:54:21.790010Z"
    }
   },
   "source": [
    "# `RNN`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:00:44.566525Z",
     "start_time": "2020-04-15T09:00:44.557123Z"
    }
   },
   "outputs": [],
   "source": [
    "sequence_length = 28\n",
    "input_size = 28\n",
    "hidden_size = 128\n",
    "num_layers = 2\n",
    "num_classes = 10\n",
    "batch_size = 100\n",
    "num_epochs = 2\n",
    "learning_rate = 0.01\n",
    "\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_layers, num_classes):\n",
    "        super(RNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        self.lstm = nn.LSTM(input_size,\n",
    "                            hidden_size,\n",
    "                            num_layers,\n",
    "                            batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        h0 = torch.zeros(self.num_layers, x.size(0),\n",
    "                         self.hidden_size).to(device)\n",
    "        c0 = torch.zeros(self.num_layers, x.size(0),\n",
    "                         self.hidden_size).to(device)\n",
    "        out, _ = self.lstm(x, (h0, c0))\n",
    "        out = self.fc(out[:, -1, :])\n",
    "        return out\n",
    "\n",
    "\n",
    "# Build the LSTM classifier from the hyperparameters above, then move it to the target device.\n",
    "model = RNN(input_size, hidden_size, num_layers, num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:00:53.891012Z",
     "start_time": "2020-04-15T09:00:45.071638Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/2], Step [100/600], Loss: 0.8858\n",
      "Epoch [1/2], Step [200/600], Loss: 0.1993\n",
      "Epoch [1/2], Step [300/600], Loss: 0.2062\n",
      "Epoch [1/2], Step [400/600], Loss: 0.0860\n",
      "Epoch [1/2], Step [500/600], Loss: 0.1567\n",
      "Epoch [1/2], Step [600/600], Loss: 0.1580\n",
      "Epoch [2/2], Step [100/600], Loss: 0.0156\n",
      "Epoch [2/2], Step [200/600], Loss: 0.2214\n",
      "Epoch [2/2], Step [300/600], Loss: 0.1073\n",
      "Epoch [2/2], Step [400/600], Loss: 0.0815\n",
      "Epoch [2/2], Step [500/600], Loss: 0.1440\n",
      "Epoch [2/2], Step [600/600], Loss: 0.0486\n"
     ]
    }
   ],
   "source": [
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Training loop\n",
    "total_step = len(train_loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Present every image as a sequence: (batch, sequence_length, input_size)\n",
    "        images = images.reshape(-1, sequence_length, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass: class scores, then the loss for this batch\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass: reset stale gradients, backpropagate, apply the update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Only every 100th step is logged\n",
    "        if (i + 1) % 100 != 0:\n",
    "            continue\n",
    "        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(\n",
    "            epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:01:02.333967Z",
     "start_time": "2020-04-15T09:01:01.760898Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy of the model on the 10000 test images: 97.68 %\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the test set in eval mode, with gradient tracking disabled.\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, sequence_length, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # The predicted class is the index of the largest score in each row.\n",
    "        _, predicted = outputs.data.max(1)\n",
    "        total += labels.size(0)\n",
    "        correct += predicted.eq(labels).sum().item()\n",
    "\n",
    "print('Test Accuracy of the model on the 10000 test images: {} %'.format(\n",
    "    100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 双向`RNN`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:03:21.855891Z",
     "start_time": "2020-04-15T09:03:21.741371Z"
    }
   },
   "outputs": [],
   "source": [
    "sequence_length = 28\n",
    "input_size = 28\n",
    "hidden_size = 128\n",
    "num_layers = 2\n",
    "num_classes = 10\n",
    "batch_size = 100\n",
    "num_epochs = 2\n",
    "learning_rate = 0.003\n",
    "\n",
    "\n",
    "class BiRNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_layers, num_classes):\n",
    "        super(BiRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        self.lstm = nn.LSTM(input_size,\n",
    "                            hidden_size,\n",
    "                            num_layers,\n",
    "                            batch_first=True,\n",
    "                            bidirectional=True)\n",
    "        self.fc = nn.Linear(hidden_size * 2, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 设置初始化参数\n",
    "        h0 = torch.zeros(self.num_layers * 2, x.size(0),\n",
    "                         self.hidden_size).to(device)\n",
    "        c0 = torch.zeros(self.num_layers * 2, x.size(0),\n",
    "                         self.hidden_size).to(device)\n",
    "\n",
    "        out, _ = self.lstm(x, (h0, c0))\n",
    "\n",
    "        out = self.fc(out[:, -1, :])\n",
    "        return out\n",
    "\n",
    "\n",
    "# Build the bidirectional LSTM classifier, then move it to the target device.\n",
    "model = BiRNN(input_size, hidden_size, num_layers, num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:03:52.707390Z",
     "start_time": "2020-04-15T09:03:41.606112Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/2], Step [100/600], Loss: 0.8344\n",
      "Epoch [1/2], Step [200/600], Loss: 0.4673\n",
      "Epoch [1/2], Step [300/600], Loss: 0.2923\n",
      "Epoch [1/2], Step [400/600], Loss: 0.0855\n",
      "Epoch [1/2], Step [500/600], Loss: 0.1698\n",
      "Epoch [1/2], Step [600/600], Loss: 0.0913\n",
      "Epoch [2/2], Step [100/600], Loss: 0.1840\n",
      "Epoch [2/2], Step [200/600], Loss: 0.1190\n",
      "Epoch [2/2], Step [300/600], Loss: 0.0809\n",
      "Epoch [2/2], Step [400/600], Loss: 0.1695\n",
      "Epoch [2/2], Step [500/600], Loss: 0.1514\n",
      "Epoch [2/2], Step [600/600], Loss: 0.1883\n"
     ]
    }
   ],
   "source": [
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Training loop\n",
    "total_step = len(train_loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Present every image as a sequence: (batch, sequence_length, input_size)\n",
    "        images = images.reshape(-1, sequence_length, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass: class scores, then the loss for this batch\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass: reset stale gradients, backpropagate, apply the update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Only every 100th step is logged\n",
    "        if (i + 1) % 100 != 0:\n",
    "            continue\n",
    "        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(\n",
    "            epoch + 1, num_epochs, i + 1, total_step, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:03:56.104550Z",
     "start_time": "2020-04-15T09:03:55.431167Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy of the model on the 10000 test images: 97.33 %\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the bidirectional LSTM on the test set.\n",
    "# Switch to evaluation mode first, as the other evaluation cells in this\n",
    "# notebook do; gradients are not needed, so tracking is disabled.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, sequence_length, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # The predicted class is the index of the largest score in each row.\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print('Test Accuracy of the model on the 10000 test images: {} %'.format(\n",
    "        100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 残差网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:21:27.799471Z",
     "start_time": "2020-04-15T09:21:27.794283Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:22:23.403316Z",
     "start_time": "2020-04-15T09:22:23.396941Z"
    }
   },
   "outputs": [],
   "source": [
    "# 指定硬件设备\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:41:12.878652Z",
     "start_time": "2020-04-15T09:41:12.873739Z"
    }
   },
   "outputs": [],
   "source": [
    "# 参数\n",
    "num_epochs = 40\n",
    "batch_size = 100\n",
    "learning_Lrate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:25:59.794638Z",
     "start_time": "2020-04-15T09:25:59.052346Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Training-time image preprocessing: pad by 4 pixels, flip horizontally at\n",
    "# random, take a random 32x32 crop, then convert to a tensor.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Pad(4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomCrop(32),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# CIFAR-10 datasets: the training split uses the augmentation above (and is\n",
    "# downloaded into 'datasets' when requested); the test split is only converted\n",
    "# to tensors.\n",
    "train_dataset = torchvision.datasets.CIFAR10(\n",
    "    'datasets', train=True, download=True, transform=transform)\n",
    "test_dataset = torchvision.datasets.CIFAR10(\n",
    "    'datasets', train=False, transform=transforms.ToTensor())\n",
    "\n",
    "# Data loaders: both splits are batched; only the training split is shuffled.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:32:34.122397Z",
     "start_time": "2020-04-15T09:32:34.114378Z"
    }
   },
   "outputs": [],
   "source": [
    "# 3*3卷积层\n",
    "def conv3(in_channels, out_channels, stride=1):\n",
    "    return nn.Conv2d(in_channels,\n",
    "                     out_channels,\n",
    "                     kernel_size=3,\n",
    "                     stride=stride,\n",
    "                     padding=1,\n",
    "                     bias=False)\n",
    "\n",
    "\n",
    "# Residual block\n",
    "class ResidualBlock(nn.Module):\n",
    "    \"\"\"Two 3x3 conv + batch-norm layers with a shortcut connection.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels of the input tensor.\n",
    "        out_channels: channels produced by both convolutions.\n",
    "        stride: stride of the first convolution.\n",
    "        downsample: optional module applied to the input before it is added to\n",
    "            the main branch; ``None`` means the input is added unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv1 = conv3(in_channels, out_channels, stride)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv3(out_channels, out_channels)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "        self.downsample = downsample\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return relu(main_branch(x) + shortcut(x)).\"\"\"\n",
    "        residual = x\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        # Compare against None explicitly: the truthiness of a module depends\n",
    "        # on __len__ (e.g. for nn.Sequential), which is unrelated to whether a\n",
    "        # shortcut projection was supplied.\n",
    "        if self.downsample is not None:\n",
    "            residual = self.downsample(x)\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:34:33.418980Z",
     "start_time": "2020-04-15T09:34:33.409768Z"
    }
   },
   "outputs": [],
   "source": [
    "# ResNet\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"Small ResNet: a 3x3 stem, three residual stages (16, 32, 64 channels),\n",
    "    8x8 average pooling and a linear classifier.\n",
    "\n",
    "    Args:\n",
    "        block: residual block class, called as\n",
    "            ``block(in_channels, out_channels, stride, downsample)``.\n",
    "        layers: sequence with the number of blocks in each of the three stages.\n",
    "        num_classes: number of output scores per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block, layers, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Channel count entering the next stage; updated by make_layer.\n",
    "        self.in_channels = 16\n",
    "        self.conv = conv3(3, 16)\n",
    "        self.bn = nn.BatchNorm2d(16)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.layer1 = self.make_layer(block, 16, layers[0])\n",
    "        self.layer2 = self.make_layer(block, 32, layers[1], 2)\n",
    "        self.layer3 = self.make_layer(block, 64, layers[2], 2)\n",
    "        self.avg_pool = nn.AvgPool2d(8)\n",
    "        self.fc = nn.Linear(64, num_classes)\n",
    "\n",
    "    def make_layer(self, block, out_channels, blocks, stride=1):\n",
    "        \"\"\"Build one stage of ``blocks`` residual blocks.\n",
    "\n",
    "        The first block applies ``stride`` and, when the stride is not 1 or the\n",
    "        channel count changes, receives a conv + batch-norm shortcut projection.\n",
    "        \"\"\"\n",
    "        needs_projection = stride != 1 or self.in_channels != out_channels\n",
    "        downsample = None\n",
    "        if needs_projection:\n",
    "            downsample = nn.Sequential(\n",
    "                conv3(self.in_channels, out_channels, stride=stride),\n",
    "                nn.BatchNorm2d(out_channels))\n",
    "        first = block(self.in_channels, out_channels, stride, downsample)\n",
    "        self.in_channels = out_channels\n",
    "        rest = [block(out_channels, out_channels) for _ in range(1, blocks)]\n",
    "        return nn.Sequential(first, *rest)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class scores of shape (batch, num_classes).\"\"\"\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        for stage in (self.layer1, self.layer2, self.layer3):\n",
    "            out = stage(out)\n",
    "        out = self.avg_pool(out)\n",
    "        # Flatten to (batch, features) for the classifier.\n",
    "        return self.fc(out.view(out.size(0), -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:46:37.809905Z",
     "start_time": "2020-04-15T09:41:18.915803Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/40], Step [100/500] Loss: 1.7607\n",
      "Epoch [1/40], Step [200/500] Loss: 1.6299\n",
      "Epoch [1/40], Step [300/500] Loss: 1.4398\n",
      "Epoch [1/40], Step [400/500] Loss: 1.4820\n",
      "Epoch [1/40], Step [500/500] Loss: 1.2206\n",
      "Epoch [2/40], Step [100/500] Loss: 1.1511\n",
      "Epoch [2/40], Step [200/500] Loss: 1.0539\n",
      "Epoch [2/40], Step [300/500] Loss: 1.0607\n",
      "Epoch [2/40], Step [400/500] Loss: 1.0667\n",
      "Epoch [2/40], Step [500/500] Loss: 1.0571\n",
      "Epoch [3/40], Step [100/500] Loss: 1.0334\n",
      "Epoch [3/40], Step [200/500] Loss: 1.0776\n",
      "Epoch [3/40], Step [300/500] Loss: 0.8903\n",
      "Epoch [3/40], Step [400/500] Loss: 0.8334\n",
      "Epoch [3/40], Step [500/500] Loss: 0.9368\n",
      "Epoch [4/40], Step [100/500] Loss: 0.7313\n",
      "Epoch [4/40], Step [200/500] Loss: 0.7194\n",
      "Epoch [4/40], Step [300/500] Loss: 0.5855\n",
      "Epoch [4/40], Step [400/500] Loss: 0.7459\n",
      "Epoch [4/40], Step [500/500] Loss: 0.8809\n",
      "Epoch [5/40], Step [100/500] Loss: 0.7540\n",
      "Epoch [5/40], Step [200/500] Loss: 0.6130\n",
      "Epoch [5/40], Step [300/500] Loss: 0.5990\n",
      "Epoch [5/40], Step [400/500] Loss: 0.4673\n",
      "Epoch [5/40], Step [500/500] Loss: 0.6930\n",
      "Epoch [6/40], Step [100/500] Loss: 0.7237\n",
      "Epoch [6/40], Step [200/500] Loss: 0.6311\n",
      "Epoch [6/40], Step [300/500] Loss: 0.6405\n",
      "Epoch [6/40], Step [400/500] Loss: 0.6288\n",
      "Epoch [6/40], Step [500/500] Loss: 0.5950\n",
      "Epoch [7/40], Step [100/500] Loss: 0.4931\n",
      "Epoch [7/40], Step [200/500] Loss: 0.4923\n",
      "Epoch [7/40], Step [300/500] Loss: 0.4681\n",
      "Epoch [7/40], Step [400/500] Loss: 0.4441\n",
      "Epoch [7/40], Step [500/500] Loss: 0.5339\n",
      "Epoch [8/40], Step [100/500] Loss: 0.4227\n",
      "Epoch [8/40], Step [200/500] Loss: 0.5205\n",
      "Epoch [8/40], Step [300/500] Loss: 0.5784\n",
      "Epoch [8/40], Step [400/500] Loss: 0.5524\n",
      "Epoch [8/40], Step [500/500] Loss: 0.4229\n",
      "Epoch [9/40], Step [100/500] Loss: 0.4240\n",
      "Epoch [9/40], Step [200/500] Loss: 0.5873\n",
      "Epoch [9/40], Step [300/500] Loss: 0.5318\n",
      "Epoch [9/40], Step [400/500] Loss: 0.4893\n",
      "Epoch [9/40], Step [500/500] Loss: 0.4894\n",
      "Epoch [10/40], Step [100/500] Loss: 0.4515\n",
      "Epoch [10/40], Step [200/500] Loss: 0.4354\n",
      "Epoch [10/40], Step [300/500] Loss: 0.4100\n",
      "Epoch [10/40], Step [400/500] Loss: 0.5183\n",
      "Epoch [10/40], Step [500/500] Loss: 0.6160\n",
      "Epoch [11/40], Step [100/500] Loss: 0.4782\n",
      "Epoch [11/40], Step [200/500] Loss: 0.3863\n",
      "Epoch [11/40], Step [300/500] Loss: 0.4927\n",
      "Epoch [11/40], Step [400/500] Loss: 0.6125\n",
      "Epoch [11/40], Step [500/500] Loss: 0.6246\n",
      "Epoch [12/40], Step [100/500] Loss: 0.4183\n",
      "Epoch [12/40], Step [200/500] Loss: 0.4231\n",
      "Epoch [12/40], Step [300/500] Loss: 0.4418\n",
      "Epoch [12/40], Step [400/500] Loss: 0.4224\n",
      "Epoch [12/40], Step [500/500] Loss: 0.5762\n",
      "Epoch [13/40], Step [100/500] Loss: 0.4196\n",
      "Epoch [13/40], Step [200/500] Loss: 0.3575\n",
      "Epoch [13/40], Step [300/500] Loss: 0.3917\n",
      "Epoch [13/40], Step [400/500] Loss: 0.3126\n",
      "Epoch [13/40], Step [500/500] Loss: 0.3146\n",
      "Epoch [14/40], Step [100/500] Loss: 0.4888\n",
      "Epoch [14/40], Step [200/500] Loss: 0.5307\n",
      "Epoch [14/40], Step [300/500] Loss: 0.4032\n",
      "Epoch [14/40], Step [400/500] Loss: 0.5026\n",
      "Epoch [14/40], Step [500/500] Loss: 0.6075\n",
      "Epoch [15/40], Step [100/500] Loss: 0.3765\n",
      "Epoch [15/40], Step [200/500] Loss: 0.4182\n",
      "Epoch [15/40], Step [300/500] Loss: 0.2634\n",
      "Epoch [15/40], Step [400/500] Loss: 0.3653\n",
      "Epoch [15/40], Step [500/500] Loss: 0.3702\n",
      "Epoch [16/40], Step [100/500] Loss: 0.3065\n",
      "Epoch [16/40], Step [200/500] Loss: 0.3962\n",
      "Epoch [16/40], Step [300/500] Loss: 0.3677\n",
      "Epoch [16/40], Step [400/500] Loss: 0.2013\n",
      "Epoch [16/40], Step [500/500] Loss: 0.5495\n",
      "Epoch [17/40], Step [100/500] Loss: 0.3000\n",
      "Epoch [17/40], Step [200/500] Loss: 0.3598\n",
      "Epoch [17/40], Step [300/500] Loss: 0.3482\n",
      "Epoch [17/40], Step [400/500] Loss: 0.3694\n",
      "Epoch [17/40], Step [500/500] Loss: 0.4694\n",
      "Epoch [18/40], Step [100/500] Loss: 0.2322\n",
      "Epoch [18/40], Step [200/500] Loss: 0.3497\n",
      "Epoch [18/40], Step [300/500] Loss: 0.4849\n",
      "Epoch [18/40], Step [400/500] Loss: 0.2544\n",
      "Epoch [18/40], Step [500/500] Loss: 0.2040\n",
      "Epoch [19/40], Step [100/500] Loss: 0.2402\n",
      "Epoch [19/40], Step [200/500] Loss: 0.3862\n",
      "Epoch [19/40], Step [300/500] Loss: 0.4124\n",
      "Epoch [19/40], Step [400/500] Loss: 0.4794\n",
      "Epoch [19/40], Step [500/500] Loss: 0.4090\n",
      "Epoch [20/40], Step [100/500] Loss: 0.2294\n",
      "Epoch [20/40], Step [200/500] Loss: 0.5178\n",
      "Epoch [20/40], Step [300/500] Loss: 0.3876\n",
      "Epoch [20/40], Step [400/500] Loss: 0.3898\n",
      "Epoch [20/40], Step [500/500] Loss: 0.4470\n",
      "Epoch [21/40], Step [100/500] Loss: 0.3514\n",
      "Epoch [21/40], Step [200/500] Loss: 0.2178\n",
      "Epoch [21/40], Step [300/500] Loss: 0.3472\n",
      "Epoch [21/40], Step [400/500] Loss: 0.2641\n",
      "Epoch [21/40], Step [500/500] Loss: 0.3125\n",
      "Epoch [22/40], Step [100/500] Loss: 0.2684\n",
      "Epoch [22/40], Step [200/500] Loss: 0.3719\n",
      "Epoch [22/40], Step [300/500] Loss: 0.2761\n",
      "Epoch [22/40], Step [400/500] Loss: 0.3740\n",
      "Epoch [22/40], Step [500/500] Loss: 0.2046\n",
      "Epoch [23/40], Step [100/500] Loss: 0.2777\n",
      "Epoch [23/40], Step [200/500] Loss: 0.5028\n",
      "Epoch [23/40], Step [300/500] Loss: 0.2641\n",
      "Epoch [23/40], Step [400/500] Loss: 0.2904\n",
      "Epoch [23/40], Step [500/500] Loss: 0.3232\n",
      "Epoch [24/40], Step [100/500] Loss: 0.3179\n",
      "Epoch [24/40], Step [200/500] Loss: 0.2584\n",
      "Epoch [24/40], Step [300/500] Loss: 0.3471\n",
      "Epoch [24/40], Step [400/500] Loss: 0.3267\n",
      "Epoch [24/40], Step [500/500] Loss: 0.3893\n",
      "Epoch [25/40], Step [100/500] Loss: 0.3460\n",
      "Epoch [25/40], Step [200/500] Loss: 0.2453\n",
      "Epoch [25/40], Step [300/500] Loss: 0.2146\n",
      "Epoch [25/40], Step [400/500] Loss: 0.1947\n",
      "Epoch [25/40], Step [500/500] Loss: 0.1843\n",
      "Epoch [26/40], Step [100/500] Loss: 0.2996\n",
      "Epoch [26/40], Step [200/500] Loss: 0.2863\n",
      "Epoch [26/40], Step [300/500] Loss: 0.2991\n",
      "Epoch [26/40], Step [400/500] Loss: 0.3122\n",
      "Epoch [26/40], Step [500/500] Loss: 0.2019\n",
      "Epoch [27/40], Step [100/500] Loss: 0.1367\n",
      "Epoch [27/40], Step [200/500] Loss: 0.1750\n",
      "Epoch [27/40], Step [300/500] Loss: 0.2844\n",
      "Epoch [27/40], Step [400/500] Loss: 0.3017\n",
      "Epoch [27/40], Step [500/500] Loss: 0.2434\n",
      "Epoch [28/40], Step [100/500] Loss: 0.2859\n",
      "Epoch [28/40], Step [200/500] Loss: 0.2684\n",
      "Epoch [28/40], Step [300/500] Loss: 0.1621\n",
      "Epoch [28/40], Step [400/500] Loss: 0.2081\n",
      "Epoch [28/40], Step [500/500] Loss: 0.2937\n",
      "Epoch [29/40], Step [100/500] Loss: 0.2929\n",
      "Epoch [29/40], Step [200/500] Loss: 0.1482\n",
      "Epoch [29/40], Step [300/500] Loss: 0.2725\n",
      "Epoch [29/40], Step [400/500] Loss: 0.3629\n",
      "Epoch [29/40], Step [500/500] Loss: 0.3045\n",
      "Epoch [30/40], Step [100/500] Loss: 0.2841\n",
      "Epoch [30/40], Step [200/500] Loss: 0.2255\n",
      "Epoch [30/40], Step [300/500] Loss: 0.1529\n",
      "Epoch [30/40], Step [400/500] Loss: 0.1822\n",
      "Epoch [30/40], Step [500/500] Loss: 0.1607\n",
      "Epoch [31/40], Step [100/500] Loss: 0.2022\n",
      "Epoch [31/40], Step [200/500] Loss: 0.2220\n",
      "Epoch [31/40], Step [300/500] Loss: 0.2554\n",
      "Epoch [31/40], Step [400/500] Loss: 0.2186\n",
      "Epoch [31/40], Step [500/500] Loss: 0.2062\n",
      "Epoch [32/40], Step [100/500] Loss: 0.2307\n",
      "Epoch [32/40], Step [200/500] Loss: 0.3077\n",
      "Epoch [32/40], Step [300/500] Loss: 0.2876\n",
      "Epoch [32/40], Step [400/500] Loss: 0.2581\n",
      "Epoch [32/40], Step [500/500] Loss: 0.2615\n",
      "Epoch [33/40], Step [100/500] Loss: 0.2354\n",
      "Epoch [33/40], Step [200/500] Loss: 0.2560\n",
      "Epoch [33/40], Step [300/500] Loss: 0.2520\n",
      "Epoch [33/40], Step [400/500] Loss: 0.2441\n",
      "Epoch [33/40], Step [500/500] Loss: 0.3266\n",
      "Epoch [34/40], Step [100/500] Loss: 0.1832\n",
      "Epoch [34/40], Step [200/500] Loss: 0.1276\n",
      "Epoch [34/40], Step [300/500] Loss: 0.1760\n",
      "Epoch [34/40], Step [400/500] Loss: 0.2144\n",
      "Epoch [34/40], Step [500/500] Loss: 0.1353\n",
      "Epoch [35/40], Step [100/500] Loss: 0.2229\n",
      "Epoch [35/40], Step [200/500] Loss: 0.2077\n",
      "Epoch [35/40], Step [300/500] Loss: 0.1868\n",
      "Epoch [35/40], Step [400/500] Loss: 0.2625\n",
      "Epoch [35/40], Step [500/500] Loss: 0.1639\n",
      "Epoch [36/40], Step [100/500] Loss: 0.0998\n",
      "Epoch [36/40], Step [200/500] Loss: 0.2043\n",
      "Epoch [36/40], Step [300/500] Loss: 0.4490\n",
      "Epoch [36/40], Step [400/500] Loss: 0.2509\n",
      "Epoch [36/40], Step [500/500] Loss: 0.2561\n",
      "Epoch [37/40], Step [100/500] Loss: 0.1361\n",
      "Epoch [37/40], Step [200/500] Loss: 0.2192\n",
      "Epoch [37/40], Step [300/500] Loss: 0.2510\n",
      "Epoch [37/40], Step [400/500] Loss: 0.2843\n",
      "Epoch [37/40], Step [500/500] Loss: 0.1470\n",
      "Epoch [38/40], Step [100/500] Loss: 0.2164\n",
      "Epoch [38/40], Step [200/500] Loss: 0.1769\n",
      "Epoch [38/40], Step [300/500] Loss: 0.2075\n",
      "Epoch [38/40], Step [400/500] Loss: 0.3146\n",
      "Epoch [38/40], Step [500/500] Loss: 0.2501\n",
      "Epoch [39/40], Step [100/500] Loss: 0.1348\n",
      "Epoch [39/40], Step [200/500] Loss: 0.2414\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [39/40], Step [300/500] Loss: 0.1406\n",
      "Epoch [39/40], Step [400/500] Loss: 0.3536\n",
      "Epoch [39/40], Step [500/500] Loss: 0.3434\n",
      "Epoch [40/40], Step [100/500] Loss: 0.0927\n",
      "Epoch [40/40], Step [200/500] Loss: 0.2798\n",
      "Epoch [40/40], Step [300/500] Loss: 0.1346\n",
      "Epoch [40/40], Step [400/500] Loss: 0.1457\n",
      "Epoch [40/40], Step [500/500] Loss: 0.2270\n"
     ]
    }
   ],
   "source": [
    "model = ResNet(ResidualBlock, [2, 2, 2]).to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "\n",
    "def update_lr(optimizer, lr):\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = lr\n",
    "\n",
    "\n",
    "total_step = len(train_loader)\n",
    "curr_lr = learning_rate\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward and optimize\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if (i + 1) % 100 == 0:\n",
    "            print(\"Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}\".format(\n",
    "                epoch + 1, num_epochs, i + 1, total_step, loss.item()))\n",
    "\n",
    "    # Decay learning rate\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        curr_lr /= 3\n",
    "        update_lr(optimizer, curr_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-15T09:48:37.266699Z",
     "start_time": "2020-04-15T09:48:36.204834Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the model on the test images: 86.99 %\n"
     ]
    }
   ],
   "source": [
    "# Test the model\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for images, labels in test_loader:\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 语言模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:05:13.973224Z",
     "start_time": "2020-04-16T08:05:13.513013Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from torch.nn.utils import clip_grad_norm_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:36:57.684166Z",
     "start_time": "2020-04-16T08:36:57.675543Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "# 处理文本文件辅助类\n",
    "class Dictionary:\n",
    "    def __init__(self):\n",
    "        self.word2index = {}\n",
    "        self.idx2word = {}\n",
    "        self.idx = 0\n",
    "\n",
    "    def add_word(self, word):\n",
    "        if not word in self.word2index:\n",
    "            self.word2index[word] = self.idx\n",
    "            self.idx2word[self.idx] = word\n",
    "            self.idx += 1\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.word2index)\n",
    "\n",
    "\n",
    "class Corpus:\n",
    "    def __init__(self):\n",
    "        self.dictionary = Dictionary()\n",
    "        self.raw_text = []\n",
    "\n",
    "    def get_data(self, path, batch_size=20):\n",
    "        # 创建词典\n",
    "        with open(path, 'r') as f:\n",
    "            tokens = 0\n",
    "            for line in f:\n",
    "                self.raw_text.append(line)\n",
    "\n",
    "                words = line.split() + ['<eos>']\n",
    "                tokens += len(words)  # 文本中总单词数量\n",
    "                for word in words:\n",
    "                    self.dictionary.add_word(word)\n",
    "                    \n",
    "        # 文本向量化\n",
    "        ids = torch.LongTensor(tokens)\n",
    "        token = 0\n",
    "\n",
    "        with open(path, 'r') as f:\n",
    "            for line in f:\n",
    "                words = line.split() + ['<eos>']\n",
    "                for word in words:\n",
    "                    ids[token] = self.dictionary.word2index[word]\n",
    "                    token += 1\n",
    "        num_batches = ids.size(0) // batch_size\n",
    "        ids = ids[:num_batches * batch_size]\n",
    "        return ids.view(batch_size, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:37:01.046288Z",
     "start_time": "2020-04-16T08:36:58.087041Z"
    }
   },
   "outputs": [],
   "source": [
    "# 加载文本，并向量化\n",
    "path = 'datasets/lm.txt'\n",
    "corpus = Corpus()\n",
    "ids = corpus.get_data(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:37:42.531701Z",
     "start_time": "2020-04-16T08:37:42.525795Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter \n",
      "\n",
      "tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])\n",
      "['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb']\n"
     ]
    }
   ],
   "source": [
    "# 原始文本\n",
    "print(corpus.raw_text[0])\n",
    "\n",
    "# 文本向量\n",
    "print(ids[0][0:15])\n",
    "\n",
    "print([corpus.dictionary.idx2word[idx.item()] for idx in ids[0][0:15]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:39:03.328121Z",
     "start_time": "2020-04-16T08:39:03.322352Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 指定硬件\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:47:46.169777Z",
     "start_time": "2020-04-16T08:47:46.163883Z"
    }
   },
   "outputs": [],
   "source": [
    "# 参数\n",
    "embed_size = 128\n",
    "hidden_size = 1024\n",
    "num_layers = 1\n",
    "num_epochs = 5\n",
    "num_samples = 1000\n",
    "batch_size = 20\n",
    "seq_length = 30\n",
    "learning_rate = 0.001\n",
    "\n",
    "num_batchs = ids.shape[1] // seq_length\n",
    "vocab_size=len(corpus.dictionary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:48:24.084952Z",
     "start_time": "2020-04-16T08:48:24.078474Z"
    }
   },
   "outputs": [],
   "source": [
    "# 基于RNN的语言模型\n",
    "\n",
    "\n",
    "class RNNLM(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):\n",
    "        super(RNNLM, self).__init__()\n",
    "        self.embed = nn.Embedding(vocab_size, embed_size)\n",
    "        self.lstm = nn.LSTM(embed_size,\n",
    "                            hidden_size,\n",
    "                            num_layers,\n",
    "                            batch_first=True)\n",
    "        self.linear = nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, x, h):\n",
    "        x = self.embed(x)\n",
    "        out, (h, c) = self.lstm(x, h)\n",
    "        out = out.reshape(out.size(0) * out.size(1), out.size(2))\n",
    "        out = self.linear(out)\n",
    "        return out, (h, c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:48:33.984922Z",
     "start_time": "2020-04-16T08:48:31.963637Z"
    }
   },
   "outputs": [],
   "source": [
    "model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:49:18.356480Z",
     "start_time": "2020-04-16T08:49:18.351255Z"
    }
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:49:53.378077Z",
     "start_time": "2020-04-16T08:49:53.373231Z"
    }
   },
   "outputs": [],
   "source": [
    "def detach(states):\n",
    "    return [state.detach() for state in states]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T08:59:49.349657Z",
     "start_time": "2020-04-16T08:58:26.101211Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Step [0/1549], Loss: 9.1072, Perplexity: 9020.009\n",
      "Epoch [1/5], Step [100/1549], Loss: 6.1745, Perplexity: 480.344\n",
      "Epoch [1/5], Step [200/1549], Loss: 6.0718, Perplexity: 433.448\n",
      "Epoch [1/5], Step [300/1549], Loss: 5.9661, Perplexity: 389.977\n",
      "Epoch [1/5], Step [400/1549], Loss: 5.8306, Perplexity: 340.549\n",
      "Epoch [1/5], Step [500/1549], Loss: 5.3281, Perplexity: 206.043\n",
      "Epoch [1/5], Step [600/1549], Loss: 5.2820, Perplexity: 196.769\n",
      "Epoch [1/5], Step [700/1549], Loss: 5.5364, Perplexity: 253.771\n",
      "Epoch [1/5], Step [800/1549], Loss: 5.2764, Perplexity: 195.657\n",
      "Epoch [1/5], Step [900/1549], Loss: 5.2341, Perplexity: 187.564\n",
      "Epoch [1/5], Step [1000/1549], Loss: 5.2817, Perplexity: 196.696\n",
      "Epoch [1/5], Step [1100/1549], Loss: 5.4569, Perplexity: 234.375\n",
      "Epoch [1/5], Step [1200/1549], Loss: 5.2746, Perplexity: 195.311\n",
      "Epoch [1/5], Step [1300/1549], Loss: 5.2142, Perplexity: 183.871\n",
      "Epoch [1/5], Step [1400/1549], Loss: 4.9683, Perplexity: 143.788\n",
      "Epoch [1/5], Step [1500/1549], Loss: 5.2321, Perplexity: 187.178\n",
      "Epoch [2/5], Step [0/1549], Loss: 5.5143, Perplexity: 248.212\n",
      "Epoch [2/5], Step [100/1549], Loss: 4.8018, Perplexity: 121.724\n",
      "Epoch [2/5], Step [200/1549], Loss: 4.8604, Perplexity: 129.072\n",
      "Epoch [2/5], Step [300/1549], Loss: 4.8850, Perplexity: 132.294\n",
      "Epoch [2/5], Step [400/1549], Loss: 4.8156, Perplexity: 123.419\n",
      "Epoch [2/5], Step [500/1549], Loss: 4.2986, Perplexity: 73.597\n",
      "Epoch [2/5], Step [600/1549], Loss: 4.6022, Perplexity: 99.708\n",
      "Epoch [2/5], Step [700/1549], Loss: 4.6148, Perplexity: 100.971\n",
      "Epoch [2/5], Step [800/1549], Loss: 4.5334, Perplexity: 93.073\n",
      "Epoch [2/5], Step [900/1549], Loss: 4.4177, Perplexity: 82.903\n",
      "Epoch [2/5], Step [1000/1549], Loss: 4.5003, Perplexity: 90.040\n",
      "Epoch [2/5], Step [1100/1549], Loss: 4.7395, Perplexity: 114.376\n",
      "Epoch [2/5], Step [1200/1549], Loss: 4.6529, Perplexity: 104.891\n",
      "Epoch [2/5], Step [1300/1549], Loss: 4.4225, Perplexity: 83.304\n",
      "Epoch [2/5], Step [1400/1549], Loss: 4.1775, Perplexity: 65.204\n",
      "Epoch [2/5], Step [1500/1549], Loss: 4.4723, Perplexity: 87.554\n",
      "Epoch [3/5], Step [0/1549], Loss: 4.6586, Perplexity: 105.492\n",
      "Epoch [3/5], Step [100/1549], Loss: 4.1702, Perplexity: 64.725\n",
      "Epoch [3/5], Step [200/1549], Loss: 4.2239, Perplexity: 68.302\n",
      "Epoch [3/5], Step [300/1549], Loss: 4.2206, Perplexity: 68.072\n",
      "Epoch [3/5], Step [400/1549], Loss: 4.1042, Perplexity: 60.597\n",
      "Epoch [3/5], Step [500/1549], Loss: 3.6042, Perplexity: 36.751\n",
      "Epoch [3/5], Step [600/1549], Loss: 4.0364, Perplexity: 56.620\n",
      "Epoch [3/5], Step [700/1549], Loss: 4.0101, Perplexity: 55.152\n",
      "Epoch [3/5], Step [800/1549], Loss: 3.9612, Perplexity: 52.522\n",
      "Epoch [3/5], Step [900/1549], Loss: 3.8032, Perplexity: 44.846\n",
      "Epoch [3/5], Step [1000/1549], Loss: 3.9016, Perplexity: 49.484\n",
      "Epoch [3/5], Step [1100/1549], Loss: 4.0868, Perplexity: 59.550\n",
      "Epoch [3/5], Step [1200/1549], Loss: 4.0488, Perplexity: 57.327\n",
      "Epoch [3/5], Step [1300/1549], Loss: 3.7428, Perplexity: 42.218\n",
      "Epoch [3/5], Step [1400/1549], Loss: 3.4713, Perplexity: 32.178\n",
      "Epoch [3/5], Step [1500/1549], Loss: 3.8378, Perplexity: 46.421\n",
      "Epoch [4/5], Step [0/1549], Loss: 3.8279, Perplexity: 45.966\n",
      "Epoch [4/5], Step [100/1549], Loss: 3.6297, Perplexity: 37.702\n",
      "Epoch [4/5], Step [200/1549], Loss: 3.7225, Perplexity: 41.368\n",
      "Epoch [4/5], Step [300/1549], Loss: 3.6463, Perplexity: 38.334\n",
      "Epoch [4/5], Step [400/1549], Loss: 3.5797, Perplexity: 35.864\n",
      "Epoch [4/5], Step [500/1549], Loss: 3.0872, Perplexity: 21.916\n",
      "Epoch [4/5], Step [600/1549], Loss: 3.5630, Perplexity: 35.267\n",
      "Epoch [4/5], Step [700/1549], Loss: 3.5234, Perplexity: 33.899\n",
      "Epoch [4/5], Step [800/1549], Loss: 3.4786, Perplexity: 32.413\n",
      "Epoch [4/5], Step [900/1549], Loss: 3.2635, Perplexity: 26.141\n",
      "Epoch [4/5], Step [1000/1549], Loss: 3.3585, Perplexity: 28.745\n",
      "Epoch [4/5], Step [1100/1549], Loss: 3.5683, Perplexity: 35.455\n",
      "Epoch [4/5], Step [1200/1549], Loss: 3.5329, Perplexity: 34.223\n",
      "Epoch [4/5], Step [1300/1549], Loss: 3.2065, Perplexity: 24.692\n",
      "Epoch [4/5], Step [1400/1549], Loss: 2.9073, Perplexity: 18.308\n",
      "Epoch [4/5], Step [1500/1549], Loss: 3.2895, Perplexity: 26.830\n",
      "Epoch [5/5], Step [0/1549], Loss: 3.2441, Perplexity: 25.640\n",
      "Epoch [5/5], Step [100/1549], Loss: 3.2038, Perplexity: 24.627\n",
      "Epoch [5/5], Step [200/1549], Loss: 3.3588, Perplexity: 28.754\n",
      "Epoch [5/5], Step [300/1549], Loss: 3.2025, Perplexity: 24.594\n",
      "Epoch [5/5], Step [400/1549], Loss: 3.1420, Perplexity: 23.150\n",
      "Epoch [5/5], Step [500/1549], Loss: 2.6720, Perplexity: 14.469\n",
      "Epoch [5/5], Step [600/1549], Loss: 3.1995, Perplexity: 24.520\n",
      "Epoch [5/5], Step [700/1549], Loss: 3.1201, Perplexity: 22.649\n",
      "Epoch [5/5], Step [800/1549], Loss: 3.1460, Perplexity: 23.242\n",
      "Epoch [5/5], Step [900/1549], Loss: 2.8151, Perplexity: 16.695\n",
      "Epoch [5/5], Step [1000/1549], Loss: 3.0033, Perplexity: 20.151\n",
      "Epoch [5/5], Step [1100/1549], Loss: 3.1089, Perplexity: 22.395\n",
      "Epoch [5/5], Step [1200/1549], Loss: 3.1492, Perplexity: 23.318\n",
      "Epoch [5/5], Step [1300/1549], Loss: 2.7896, Perplexity: 16.275\n",
      "Epoch [5/5], Step [1400/1549], Loss: 2.4877, Perplexity: 12.034\n",
      "Epoch [5/5], Step [1500/1549], Loss: 2.8676, Perplexity: 17.595\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),\n",
    "              torch.zeros(num_layers, batch_size, hidden_size).to(device))\n",
    "\n",
    "    for i in range(0, ids.size(1) - seq_length, seq_length):\n",
    "        inputs = ids[:, i:i + seq_length].to(device)\n",
    "        targets = ids[:, (i + 1):(i + 1 + seq_length)].to(device)\n",
    "\n",
    "        states = detach(states)\n",
    "        outputs, states = model(inputs, states)\n",
    "        loss = criterion(outputs, targets.reshape(-1))\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        clip_grad_norm_(model.parameters(), 0.5)\n",
    "        optimizer.step()\n",
    "\n",
    "        step = (i + 1) // seq_length\n",
    "        if step % 500 == 0:\n",
    "            print(\n",
    "                \"Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.3f}\"\n",
    "                .format(epoch + 1, num_epochs, step, num_batchs, loss.item(),\n",
    "                        np.exp(loss.item())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-16T09:06:13.756454Z",
     "start_time": "2020-04-16T09:06:13.460887Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    state = (torch.zeros(num_layers, 1, hidden_size).to(device),\n",
    "             torch.zeros(num_layers, 1, hidden_size).to(device))\n",
    "    prob = torch.ones(vocab_size)\n",
    "    input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)\n",
    "\n",
    "    for i in range(num_samples):\n",
    "        output, state = model(input, state)\n",
    "        prob = output.exp()\n",
    "        word_id = torch.multinomial(prob, num_samples=1).item()\n",
    "\n",
    "        word = corpus.dictionary.idx2word[word_id]\n",
    "        word = '\\n' if word == '<eos>' else word + ' '\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "192.4px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
